forked from jshiffer/go-xmpp
		
	Added callback to process errors after connection.
Added tests and refactored a bit.
This commit is contained in:
		
							
								
								
									
										95
									
								
								_examples/xmpp_chat_client/xmpp_chat_client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								_examples/xmpp_chat_client/xmpp_chat_client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,95 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
xmpp_chat_client is a demo client that connect on an XMPP server to chat with other members
 | 
			
		||||
Note that this example sends to a very specific user. User logic is not implemented here.
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	. "bufio"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	"gosrc.io/xmpp"
 | 
			
		||||
	"gosrc.io/xmpp/stanza"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	currentUserAddress = "localhost:5222"
 | 
			
		||||
	currentUserJid     = "testuser@localhost"
 | 
			
		||||
	currentUserPass    = "testpass"
 | 
			
		||||
	correspondantJid   = "testuser2@localhost"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	config := xmpp.Config{
 | 
			
		||||
		TransportConfiguration: xmpp.TransportConfiguration{
 | 
			
		||||
			Address: currentUserAddress,
 | 
			
		||||
		},
 | 
			
		||||
		Jid:        currentUserJid,
 | 
			
		||||
		Credential: xmpp.Password(currentUserPass),
 | 
			
		||||
		Insecure:   true}
 | 
			
		||||
 | 
			
		||||
	var client *xmpp.Client
 | 
			
		||||
	var err error
 | 
			
		||||
	router := xmpp.NewRouter()
 | 
			
		||||
	router.HandleFunc("message", handleMessage)
 | 
			
		||||
	if client, err = xmpp.NewClient(config, router, errorHandler); err != nil {
 | 
			
		||||
		fmt.Println("Error new client")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Connecting client and handling messages
 | 
			
		||||
	// To use a stream manager, just write something like this instead :
 | 
			
		||||
	//cm := xmpp.NewStreamManager(client, startMessaging)
 | 
			
		||||
	//log.Fatal(cm.Run()) //=> this will lock the calling goroutine
 | 
			
		||||
 | 
			
		||||
	if err = client.Connect(); err != nil {
 | 
			
		||||
		fmt.Printf("XMPP connection failed: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	startMessaging(client)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func startMessaging(client xmpp.Sender) {
 | 
			
		||||
	reader := NewReader(os.Stdin)
 | 
			
		||||
	textChan := make(chan string)
 | 
			
		||||
	var text string
 | 
			
		||||
	for {
 | 
			
		||||
		fmt.Print("Enter text: ")
 | 
			
		||||
		go readInput(reader, textChan)
 | 
			
		||||
		select {
 | 
			
		||||
		case <-killChan:
 | 
			
		||||
			return
 | 
			
		||||
		case text = <-textChan:
 | 
			
		||||
			reply := stanza.Message{Attrs: stanza.Attrs{To: correspondantJid}, Body: text}
 | 
			
		||||
			err := client.Send(reply)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				fmt.Printf("There was a problem sending the message : %v", reply)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func readInput(reader *Reader, textChan chan string) {
 | 
			
		||||
	text, _ := reader.ReadString('\n')
 | 
			
		||||
	textChan <- text
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var killChan = make(chan struct{})
 | 
			
		||||
 | 
			
		||||
// If an error occurs, this is used
 | 
			
		||||
func errorHandler(err error) {
 | 
			
		||||
	fmt.Printf("%v", err)
 | 
			
		||||
	killChan <- struct{}{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleMessage(s xmpp.Sender, p stanza.Packet) {
 | 
			
		||||
	msg, ok := p.(stanza.Message)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		_, _ = fmt.Fprintf(os.Stdout, "Ignoring packet: %T\n", p)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	_, _ = fmt.Fprintf(os.Stdout, "Body = %s - from = %s\n", msg.Body, msg.From)
 | 
			
		||||
}
 | 
			
		||||
@@ -35,7 +35,7 @@ func main() {
 | 
			
		||||
		IQNamespaces("jabber:iq:version").
 | 
			
		||||
		HandlerFunc(handleVersion)
 | 
			
		||||
 | 
			
		||||
	component, err := xmpp.NewComponent(opts, router)
 | 
			
		||||
	component, err := xmpp.NewComponent(opts, router, handleError)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -47,6 +47,10 @@ func main() {
 | 
			
		||||
	log.Fatal(cm.Run())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleError(err error) {
 | 
			
		||||
	fmt.Println(err.Error())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleMessage(_ xmpp.Sender, p stanza.Packet) {
 | 
			
		||||
	msg, ok := p.(stanza.Message)
 | 
			
		||||
	if !ok {
 | 
			
		||||
 
 | 
			
		||||
@@ -53,7 +53,7 @@ func main() {
 | 
			
		||||
			handleIQ(s, p, player)
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
	client, err := xmpp.NewClient(config, router)
 | 
			
		||||
	client, err := xmpp.NewClient(config, router, errorHandler)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -61,6 +61,9 @@ func main() {
 | 
			
		||||
	cm := xmpp.NewStreamManager(client, nil)
 | 
			
		||||
	log.Fatal(cm.Run())
 | 
			
		||||
}
 | 
			
		||||
func errorHandler(err error) {
 | 
			
		||||
	fmt.Println(err.Error())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleMessage(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) {
 | 
			
		||||
	msg, ok := p.(stanza.Message)
 | 
			
		||||
 
 | 
			
		||||
@@ -28,7 +28,7 @@ func main() {
 | 
			
		||||
	router := xmpp.NewRouter()
 | 
			
		||||
	router.HandleFunc("message", handleMessage)
 | 
			
		||||
 | 
			
		||||
	client, err := xmpp.NewClient(config, router)
 | 
			
		||||
	client, err := xmpp.NewClient(config, router, errorHandler)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -39,6 +39,10 @@ func main() {
 | 
			
		||||
	log.Fatal(cm.Run())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func errorHandler(err error) {
 | 
			
		||||
	fmt.Println(err.Error())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleMessage(s xmpp.Sender, p stanza.Packet) {
 | 
			
		||||
	msg, ok := p.(stanza.Message)
 | 
			
		||||
	if !ok {
 | 
			
		||||
 
 | 
			
		||||
@@ -26,7 +26,7 @@ func main() {
 | 
			
		||||
	router := xmpp.NewRouter()
 | 
			
		||||
	router.HandleFunc("message", handleMessage)
 | 
			
		||||
 | 
			
		||||
	client, err := xmpp.NewClient(config, router)
 | 
			
		||||
	client, err := xmpp.NewClient(config, router, errorHandler)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -37,6 +37,10 @@ func main() {
 | 
			
		||||
	log.Fatal(cm.Run())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func errorHandler(err error) {
 | 
			
		||||
	fmt.Println(err.Error())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleMessage(s xmpp.Sender, p stanza.Packet) {
 | 
			
		||||
	msg, ok := p.(stanza.Message)
 | 
			
		||||
	if !ok {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										18
									
								
								client.go
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								client.go
									
									
									
									
									
								
							@@ -98,6 +98,8 @@ type Client struct {
 | 
			
		||||
	router *Router
 | 
			
		||||
	// Track and broadcast connection state
 | 
			
		||||
	EventManager
 | 
			
		||||
	// Handle errors from client execution
 | 
			
		||||
	ErrorHandler func(error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
@@ -107,7 +109,7 @@ Setting up the client / Checking the parameters
 | 
			
		||||
// NewClient generates a new XMPP client, based on Config passed as parameters.
 | 
			
		||||
// If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID.
 | 
			
		||||
// Default the port to 5222.
 | 
			
		||||
func NewClient(config Config, r *Router) (c *Client, err error) {
 | 
			
		||||
func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, err error) {
 | 
			
		||||
	if config.KeepaliveInterval == 0 {
 | 
			
		||||
		config.KeepaliveInterval = time.Second * 30
 | 
			
		||||
	}
 | 
			
		||||
@@ -143,6 +145,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
 | 
			
		||||
	c = new(Client)
 | 
			
		||||
	c.config = config
 | 
			
		||||
	c.router = r
 | 
			
		||||
	c.ErrorHandler = errorHandler
 | 
			
		||||
 | 
			
		||||
	if c.config.ConnectTimeout == 0 {
 | 
			
		||||
		c.config.ConnectTimeout = 15 // 15 second as default
 | 
			
		||||
@@ -191,10 +194,7 @@ func (c *Client) Resume(state SMState) error {
 | 
			
		||||
	go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit)
 | 
			
		||||
	// Start the receiver go routine
 | 
			
		||||
	state = c.Session.SMState
 | 
			
		||||
	// Leaving this channel here for later. Not used atm. We should return this instead of an error because right
 | 
			
		||||
	// now the returned error is lost in limbo.
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	go c.recv(state, keepaliveQuit, errChan)
 | 
			
		||||
	go c.recv(state, keepaliveQuit)
 | 
			
		||||
 | 
			
		||||
	// We're connected and can now receive and send messages.
 | 
			
		||||
	//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
 | 
			
		||||
@@ -273,11 +273,11 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
 | 
			
		||||
// Go routines
 | 
			
		||||
 | 
			
		||||
// Loop: Receive data from server
 | 
			
		||||
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) {
 | 
			
		||||
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) {
 | 
			
		||||
	for {
 | 
			
		||||
		val, err := stanza.NextPacket(c.transport.GetDecoder())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			errChan <- err
 | 
			
		||||
			c.ErrorHandler(err)
 | 
			
		||||
			close(keepaliveQuit)
 | 
			
		||||
			c.disconnected(state)
 | 
			
		||||
			return
 | 
			
		||||
@@ -289,7 +289,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan
 | 
			
		||||
			c.router.route(c, val)
 | 
			
		||||
			close(keepaliveQuit)
 | 
			
		||||
			c.streamError(packet.Error.Local, packet.Text)
 | 
			
		||||
			errChan <- errors.New("stream error: " + packet.Error.Local)
 | 
			
		||||
			c.ErrorHandler(errors.New("stream error: " + packet.Error.Local))
 | 
			
		||||
			return
 | 
			
		||||
		// Process Stream management nonzas
 | 
			
		||||
		case stanza.SMRequest:
 | 
			
		||||
@@ -299,7 +299,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan
 | 
			
		||||
			}, H: state.Inbound}
 | 
			
		||||
			err = c.Send(answer)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				errChan <- err
 | 
			
		||||
				c.ErrorHandler(err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										374
									
								
								client_test.go
									
									
									
									
									
								
							
							
						
						
									
										374
									
								
								client_test.go
									
									
									
									
									
								
							@@ -1,6 +1,7 @@
 | 
			
		||||
package xmpp
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/xml"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -14,9 +15,8 @@ import (
 | 
			
		||||
const (
 | 
			
		||||
	// Default port is not standard XMPP port to avoid interfering
 | 
			
		||||
	// with local running XMPP server
 | 
			
		||||
	testXMPPAddress = "localhost:15222"
 | 
			
		||||
 | 
			
		||||
	defaultTimeout = 2 * time.Second
 | 
			
		||||
	testXMPPAddress  = "localhost:15222"
 | 
			
		||||
	testClientDomain = "localhost"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestEventManager(t *testing.T) {
 | 
			
		||||
@@ -40,7 +40,7 @@ func TestEventManager(t *testing.T) {
 | 
			
		||||
func TestClient_Connect(t *testing.T) {
 | 
			
		||||
	// Setup Mock server
 | 
			
		||||
	mock := ServerMock{}
 | 
			
		||||
	mock.Start(t, testXMPPAddress, handlerConnectSuccess)
 | 
			
		||||
	mock.Start(t, testXMPPAddress, handlerClientConnectSuccess)
 | 
			
		||||
 | 
			
		||||
	// Test / Check result
 | 
			
		||||
	config := Config{
 | 
			
		||||
@@ -54,7 +54,7 @@ func TestClient_Connect(t *testing.T) {
 | 
			
		||||
	var client *Client
 | 
			
		||||
	var err error
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	if client, err = NewClient(config, router); err != nil {
 | 
			
		||||
	if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
 | 
			
		||||
		t.Errorf("connect create XMPP client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -82,7 +82,7 @@ func TestClient_NoInsecure(t *testing.T) {
 | 
			
		||||
	var client *Client
 | 
			
		||||
	var err error
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	if client, err = NewClient(config, router); err != nil {
 | 
			
		||||
	if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
 | 
			
		||||
		t.Errorf("cannot create XMPP client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -112,7 +112,7 @@ func TestClient_FeaturesTracking(t *testing.T) {
 | 
			
		||||
	var client *Client
 | 
			
		||||
	var err error
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	if client, err = NewClient(config, router); err != nil {
 | 
			
		||||
	if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
 | 
			
		||||
		t.Errorf("cannot create XMPP client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -127,7 +127,7 @@ func TestClient_FeaturesTracking(t *testing.T) {
 | 
			
		||||
func TestClient_RFC3921Session(t *testing.T) {
 | 
			
		||||
	// Setup Mock server
 | 
			
		||||
	mock := ServerMock{}
 | 
			
		||||
	mock.Start(t, testXMPPAddress, handlerConnectWithSession)
 | 
			
		||||
	mock.Start(t, testXMPPAddress, handlerClientConnectWithSession)
 | 
			
		||||
 | 
			
		||||
	// Test / Check result
 | 
			
		||||
	config := Config{
 | 
			
		||||
@@ -142,7 +142,7 @@ func TestClient_RFC3921Session(t *testing.T) {
 | 
			
		||||
	var client *Client
 | 
			
		||||
	var err error
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	if client, err = NewClient(config, router); err != nil {
 | 
			
		||||
	if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
 | 
			
		||||
		t.Errorf("connect create XMPP client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -153,48 +153,254 @@ func TestClient_RFC3921Session(t *testing.T) {
 | 
			
		||||
	mock.Stop()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Testing sending an IQ to the mock server and reading its response.
 | 
			
		||||
func TestClient_SendIQ(t *testing.T) {
 | 
			
		||||
	done := make(chan struct{})
 | 
			
		||||
	// Handler for Mock server
 | 
			
		||||
	h := func(t *testing.T, c net.Conn) {
 | 
			
		||||
		handlerClientConnectSuccess(t, c)
 | 
			
		||||
		discardPresence(t, c)
 | 
			
		||||
		respondToIQ(t, c)
 | 
			
		||||
		done <- struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	client, mock := mockClientConnection(t, h, testClientIqPort)
 | 
			
		||||
 | 
			
		||||
	ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
 | 
			
		||||
	iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
 | 
			
		||||
	disco := iqReq.DiscoInfo()
 | 
			
		||||
	iqReq.Payload = disco
 | 
			
		||||
 | 
			
		||||
	// Handle a possible error
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	errorHandler := func(err error) {
 | 
			
		||||
		errChan <- err
 | 
			
		||||
	}
 | 
			
		||||
	client.ErrorHandler = errorHandler
 | 
			
		||||
	res, err := client.SendIQ(ctx, iqReq)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf(err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-res: // If the server responds with an IQ, we pass the test
 | 
			
		||||
	case err := <-errChan: // If the server sends an error, or there is a connection error
 | 
			
		||||
		t.Errorf(err.Error())
 | 
			
		||||
	case <-time.After(defaultChannelTimeout): // If we timeout
 | 
			
		||||
		t.Errorf("Failed to receive response, to sent IQ, from mock server")
 | 
			
		||||
	}
 | 
			
		||||
	select {
 | 
			
		||||
	case <-done:
 | 
			
		||||
		mock.Stop()
 | 
			
		||||
	case <-time.After(defaultChannelTimeout):
 | 
			
		||||
		t.Errorf("The mock server failed to finish its job !")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestClient_SendIQFail(t *testing.T) {
 | 
			
		||||
	done := make(chan struct{})
 | 
			
		||||
	// Handler for Mock server
 | 
			
		||||
	h := func(t *testing.T, c net.Conn) {
 | 
			
		||||
		handlerClientConnectSuccess(t, c)
 | 
			
		||||
		discardPresence(t, c)
 | 
			
		||||
		respondToIQ(t, c)
 | 
			
		||||
		done <- struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	client, mock := mockClientConnection(t, h, testClientIqFailPort)
 | 
			
		||||
 | 
			
		||||
	//==================
 | 
			
		||||
	// Create an IQ to send
 | 
			
		||||
	ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
 | 
			
		||||
	iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
 | 
			
		||||
	disco := iqReq.DiscoInfo()
 | 
			
		||||
	iqReq.Payload = disco
 | 
			
		||||
	// Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified
 | 
			
		||||
	// so we need to overwrite it.
 | 
			
		||||
	iqReq.Id = ""
 | 
			
		||||
 | 
			
		||||
	// Handle a possible error
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	errorHandler := func(err error) {
 | 
			
		||||
		errChan <- err
 | 
			
		||||
	}
 | 
			
		||||
	client.ErrorHandler = errorHandler
 | 
			
		||||
	res, _ := client.SendIQ(ctx, iqReq)
 | 
			
		||||
 | 
			
		||||
	// Test
 | 
			
		||||
	select {
 | 
			
		||||
	case <-res: // If the server responds with an IQ
 | 
			
		||||
		t.Errorf("Server should not respond with an IQ since the request is expected to be invalid !")
 | 
			
		||||
	case <-errChan: // If the server sends an error, the test passes
 | 
			
		||||
	case <-time.After(defaultChannelTimeout): // If we timeout
 | 
			
		||||
		t.Errorf("Failed to receive response, to sent IQ, from mock server")
 | 
			
		||||
	}
 | 
			
		||||
	select {
 | 
			
		||||
	case <-done:
 | 
			
		||||
		mock.Stop()
 | 
			
		||||
	case <-time.After(defaultChannelTimeout):
 | 
			
		||||
		t.Errorf("The mock server failed to finish its job !")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestClient_SendRaw(t *testing.T) {
 | 
			
		||||
	done := make(chan struct{})
 | 
			
		||||
	// Handler for Mock server
 | 
			
		||||
	h := func(t *testing.T, c net.Conn) {
 | 
			
		||||
		handlerClientConnectSuccess(t, c)
 | 
			
		||||
		discardPresence(t, c)
 | 
			
		||||
		respondToIQ(t, c)
 | 
			
		||||
		done <- struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	type testCase struct {
 | 
			
		||||
		req       string
 | 
			
		||||
		shouldErr bool
 | 
			
		||||
		port      int
 | 
			
		||||
	}
 | 
			
		||||
	testRequests := make(map[string]testCase)
 | 
			
		||||
	// Sending a correct IQ of type get. Not supposed to err
 | 
			
		||||
	testRequests["Correct IQ"] = testCase{
 | 
			
		||||
		req:       `<iq type="get" id="91bd0bba-012f-4d92-bb17-5fc41e6fe545" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
 | 
			
		||||
		shouldErr: false,
 | 
			
		||||
		port:      testClientRawPort + 100,
 | 
			
		||||
	}
 | 
			
		||||
	// Sending an IQ with a missing ID. Should err
 | 
			
		||||
	testRequests["IQ with missing ID"] = testCase{
 | 
			
		||||
		req:       `<iq type="get" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
 | 
			
		||||
		shouldErr: true,
 | 
			
		||||
		port:      testClientRawPort,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// A handler for the client.
 | 
			
		||||
	// In the failing test, the server returns a stream error, which triggers this handler, client side.
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	errHandler := func(err error) {
 | 
			
		||||
		errChan <- err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Tests for all the IQs
 | 
			
		||||
	for name, tcase := range testRequests {
 | 
			
		||||
		t.Run(name, func(st *testing.T) {
 | 
			
		||||
			//Connecting to a mock server, initialized with given port and handler function
 | 
			
		||||
			c, m := mockClientConnection(t, h, tcase.port)
 | 
			
		||||
			c.ErrorHandler = errHandler
 | 
			
		||||
			// Sending raw xml from test case
 | 
			
		||||
			err := c.SendRaw(tcase.req)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Errorf("Error sending Raw string")
 | 
			
		||||
			}
 | 
			
		||||
			// Just wait a little so the message has time to arrive
 | 
			
		||||
			select {
 | 
			
		||||
			// We don't use the default "long" timeout here because waiting it out means passing the test.
 | 
			
		||||
			case <-time.After(100 * time.Millisecond):
 | 
			
		||||
			case err = <-errChan:
 | 
			
		||||
				if err == nil && tcase.shouldErr {
 | 
			
		||||
					t.Errorf("Failed to get closing stream err")
 | 
			
		||||
				} else if err != nil && !tcase.shouldErr {
 | 
			
		||||
					t.Errorf("This test is not supposed to err !")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			c.transport.Close()
 | 
			
		||||
			select {
 | 
			
		||||
			case <-done:
 | 
			
		||||
				m.Stop()
 | 
			
		||||
			case <-time.After(defaultChannelTimeout):
 | 
			
		||||
				t.Errorf("The mock server failed to finish its job !")
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestClient_Disconnect(t *testing.T) {
 | 
			
		||||
	c, m := mockClientConnection(t, handlerClientConnectSuccess, testClientBasePort)
 | 
			
		||||
	err := c.transport.Ping()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Could not ping but not disconnected yet")
 | 
			
		||||
	}
 | 
			
		||||
	c.Disconnect()
 | 
			
		||||
	err = c.transport.Ping()
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Errorf("Did not disconnect properly")
 | 
			
		||||
	}
 | 
			
		||||
	m.Stop()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestClient_DisconnectStreamManager(t *testing.T) {
 | 
			
		||||
	// Init mock server
 | 
			
		||||
	// Setup Mock server
 | 
			
		||||
	mock := ServerMock{}
 | 
			
		||||
	mock.Start(t, testXMPPAddress, handlerAbortTLS)
 | 
			
		||||
 | 
			
		||||
	// Test / Check result
 | 
			
		||||
	config := Config{
 | 
			
		||||
		TransportConfiguration: TransportConfiguration{
 | 
			
		||||
			Address: testXMPPAddress,
 | 
			
		||||
		},
 | 
			
		||||
		Jid:        "test@localhost",
 | 
			
		||||
		Credential: Password("test"),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var client *Client
 | 
			
		||||
	var err error
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
 | 
			
		||||
		t.Errorf("cannot create XMPP client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sman := NewStreamManager(client, nil)
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	runSMan := func(errChan chan error) {
 | 
			
		||||
		errChan <- sman.Run()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go runSMan(errChan)
 | 
			
		||||
	select {
 | 
			
		||||
	case <-errChan:
 | 
			
		||||
	case <-time.After(defaultChannelTimeout):
 | 
			
		||||
		// When insecure is not allowed:
 | 
			
		||||
		t.Errorf("should fail as insecure connection is not allowed and server does not support TLS")
 | 
			
		||||
	}
 | 
			
		||||
	mock.Stop()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//=============================================================================
 | 
			
		||||
// Basic XMPP Server Mock Handlers.
 | 
			
		||||
 | 
			
		||||
const serverStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' id='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
 | 
			
		||||
 | 
			
		||||
// Test connection with a basic straightforward workflow
 | 
			
		||||
func handlerConnectSuccess(t *testing.T, c net.Conn) {
 | 
			
		||||
func handlerClientConnectSuccess(t *testing.T, c net.Conn) {
 | 
			
		||||
	decoder := xml.NewDecoder(c)
 | 
			
		||||
	checkOpenStream(t, c, decoder)
 | 
			
		||||
	checkClientOpenStream(t, c, decoder)
 | 
			
		||||
 | 
			
		||||
	sendStreamFeatures(t, c, decoder) // Send initial features
 | 
			
		||||
	readAuth(t, decoder)
 | 
			
		||||
	fmt.Fprintln(c, "<success xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"/>")
 | 
			
		||||
 | 
			
		||||
	checkOpenStream(t, c, decoder) // Reset stream
 | 
			
		||||
	sendBindFeature(t, c, decoder) // Send post auth features
 | 
			
		||||
	checkClientOpenStream(t, c, decoder) // Reset stream
 | 
			
		||||
	sendBindFeature(t, c, decoder)       // Send post auth features
 | 
			
		||||
	bind(t, c, decoder)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// We expect client will abort on TLS
 | 
			
		||||
func handlerAbortTLS(t *testing.T, c net.Conn) {
 | 
			
		||||
	decoder := xml.NewDecoder(c)
 | 
			
		||||
	checkOpenStream(t, c, decoder)
 | 
			
		||||
	checkClientOpenStream(t, c, decoder)
 | 
			
		||||
	sendStreamFeatures(t, c, decoder) // Send initial features
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Test connection with mandatory session (RFC-3921)
 | 
			
		||||
func handlerConnectWithSession(t *testing.T, c net.Conn) {
 | 
			
		||||
func handlerClientConnectWithSession(t *testing.T, c net.Conn) {
 | 
			
		||||
	decoder := xml.NewDecoder(c)
 | 
			
		||||
	checkOpenStream(t, c, decoder)
 | 
			
		||||
	checkClientOpenStream(t, c, decoder)
 | 
			
		||||
 | 
			
		||||
	sendStreamFeatures(t, c, decoder) // Send initial features
 | 
			
		||||
	readAuth(t, decoder)
 | 
			
		||||
	fmt.Fprintln(c, "<success xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"/>")
 | 
			
		||||
 | 
			
		||||
	checkOpenStream(t, c, decoder)    // Reset stream
 | 
			
		||||
	sendRFC3921Feature(t, c, decoder) // Send post auth features
 | 
			
		||||
	checkClientOpenStream(t, c, decoder) // Reset stream
 | 
			
		||||
	sendRFC3921Feature(t, c, decoder)    // Send post auth features
 | 
			
		||||
	bind(t, c, decoder)
 | 
			
		||||
	session(t, c, decoder)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
	c.SetDeadline(time.Now().Add(defaultTimeout))
 | 
			
		||||
	defer c.SetDeadline(time.Time{})
 | 
			
		||||
 | 
			
		||||
@@ -220,105 +426,35 @@ func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) {
 | 
			
		||||
	// This is a basic server, supporting only 1 stream feature: SASL Plain Auth
 | 
			
		||||
	features := `<stream:features>
 | 
			
		||||
  <mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl">
 | 
			
		||||
    <mechanism>PLAIN</mechanism>
 | 
			
		||||
  </mechanisms>
 | 
			
		||||
</stream:features>`
 | 
			
		||||
	if _, err := fmt.Fprintln(c, features); err != nil {
 | 
			
		||||
		t.Errorf("cannot send stream feature: %s", err)
 | 
			
		||||
func mockClientConnection(t *testing.T, serverHandler func(*testing.T, net.Conn), port int) (*Client, ServerMock) {
 | 
			
		||||
	mock := ServerMock{}
 | 
			
		||||
	testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port)
 | 
			
		||||
 | 
			
		||||
	mock.Start(t, testServerAddress, serverHandler)
 | 
			
		||||
 | 
			
		||||
	config := Config{
 | 
			
		||||
		TransportConfiguration: TransportConfiguration{
 | 
			
		||||
			Address: testServerAddress,
 | 
			
		||||
		},
 | 
			
		||||
		Jid:        "test@localhost",
 | 
			
		||||
		Credential: Password("test"),
 | 
			
		||||
		Insecure:   true}
 | 
			
		||||
 | 
			
		||||
	var client *Client
 | 
			
		||||
	var err error
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
 | 
			
		||||
		t.Errorf("connect create XMPP client: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = client.Connect(); err != nil {
 | 
			
		||||
		t.Errorf("XMPP connection failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return client, mock
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO return err in case of error reading the auth params
 | 
			
		||||
func readAuth(t *testing.T, decoder *xml.Decoder) string {
 | 
			
		||||
	se, err := stanza.NextStart(decoder)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("cannot read auth: %s", err)
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var nv interface{}
 | 
			
		||||
	nv = &stanza.SASLAuth{}
 | 
			
		||||
	// Decode element into pointer storage
 | 
			
		||||
	if err = decoder.DecodeElement(nv, &se); err != nil {
 | 
			
		||||
		t.Errorf("cannot decode auth: %s", err)
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch v := nv.(type) {
 | 
			
		||||
	case *stanza.SASLAuth:
 | 
			
		||||
		return v.Value
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) {
 | 
			
		||||
	// This is a basic server, supporting only 1 stream feature after auth: resource binding
 | 
			
		||||
	features := `<stream:features>
 | 
			
		||||
  <bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>
 | 
			
		||||
</stream:features>`
 | 
			
		||||
	if _, err := fmt.Fprintln(c, features); err != nil {
 | 
			
		||||
		t.Errorf("cannot send stream feature: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) {
 | 
			
		||||
	// This is a basic server, supporting only 2 features after auth: resource & session binding
 | 
			
		||||
	features := `<stream:features>
 | 
			
		||||
  <bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>
 | 
			
		||||
  <session xmlns='urn:ietf:params:xml:ns:xmpp-session'/>
 | 
			
		||||
</stream:features>`
 | 
			
		||||
	if _, err := fmt.Fprintln(c, features); err != nil {
 | 
			
		||||
		t.Errorf("cannot send stream feature: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
	se, err := stanza.NextStart(decoder)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("cannot read bind: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	iq := &stanza.IQ{}
 | 
			
		||||
	// Decode element into pointer storage
 | 
			
		||||
	if err = decoder.DecodeElement(&iq, &se); err != nil {
 | 
			
		||||
		t.Errorf("cannot decode bind iq: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO Check all elements
 | 
			
		||||
	switch iq.Payload.(type) {
 | 
			
		||||
	case *stanza.Bind:
 | 
			
		||||
		result := `<iq id='%s' type='result'>
 | 
			
		||||
  <bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'>
 | 
			
		||||
  	<jid>%s</jid>
 | 
			
		||||
  </bind>
 | 
			
		||||
</iq>`
 | 
			
		||||
		fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func session(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
	se, err := stanza.NextStart(decoder)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("cannot read session: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	iq := &stanza.IQ{}
 | 
			
		||||
	// Decode element into pointer storage
 | 
			
		||||
	if err = decoder.DecodeElement(&iq, &se); err != nil {
 | 
			
		||||
		t.Errorf("cannot decode session iq: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch iq.Payload.(type) {
 | 
			
		||||
	case *stanza.StreamSession:
 | 
			
		||||
		result := `<iq id='%s' type='result'/>`
 | 
			
		||||
		fmt.Fprintf(c, result, iq.Id)
 | 
			
		||||
	}
 | 
			
		||||
// This really should not be used as is.
 | 
			
		||||
// It's just meant to be a placeholder when error handling is not needed at this level
 | 
			
		||||
func clientDefaultErrorHandler(err error) {
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										20
									
								
								component.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								component.go
									
									
									
									
									
								
							@@ -48,11 +48,12 @@ type Component struct {
 | 
			
		||||
	transport Transport
 | 
			
		||||
 | 
			
		||||
	// read / write
 | 
			
		||||
	socketProxy io.ReadWriter // TODO
 | 
			
		||||
	socketProxy  io.ReadWriter // TODO
 | 
			
		||||
	ErrorHandler func(error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewComponent(opts ComponentOptions, r *Router) (*Component, error) {
 | 
			
		||||
	c := Component{ComponentOptions: opts, router: r}
 | 
			
		||||
func NewComponent(opts ComponentOptions, r *Router, errorHandler func(error)) (*Component, error) {
 | 
			
		||||
	c := Component{ComponentOptions: opts, router: r, ErrorHandler: errorHandler}
 | 
			
		||||
	return &c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -104,11 +105,8 @@ func (c *Component) Resume(sm SMState) error {
 | 
			
		||||
	case stanza.Handshake:
 | 
			
		||||
		// Start the receiver go routine
 | 
			
		||||
		c.updateState(StateSessionEstablished)
 | 
			
		||||
		// Leaving this channel here for later. Not used atm. We should return this instead of an error because right
 | 
			
		||||
		// now the returned error is lost in limbo.
 | 
			
		||||
		errChan := make(chan error)
 | 
			
		||||
		go c.recv(errChan) // Sends to errChan
 | 
			
		||||
		return err         // Should be empty at this point
 | 
			
		||||
		go c.recv()
 | 
			
		||||
		return err // Should be empty at this point
 | 
			
		||||
	default:
 | 
			
		||||
		c.updateState(StatePermanentError)
 | 
			
		||||
		return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true)
 | 
			
		||||
@@ -128,13 +126,13 @@ func (c *Component) SetHandler(handler EventHandler) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Receiver Go routine receiver
 | 
			
		||||
func (c *Component) recv(errChan chan<- error) {
 | 
			
		||||
func (c *Component) recv() {
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		val, err := stanza.NextPacket(c.transport.GetDecoder())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.updateState(StateDisconnected)
 | 
			
		||||
			errChan <- err
 | 
			
		||||
			c.ErrorHandler(err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// Handle stream errors
 | 
			
		||||
@@ -142,7 +140,7 @@ func (c *Component) recv(errChan chan<- error) {
 | 
			
		||||
		case stanza.StreamError:
 | 
			
		||||
			c.router.route(c, val)
 | 
			
		||||
			c.streamError(p.Error.Local, p.Text)
 | 
			
		||||
			errChan <- errors.New("stream error: " + p.Error.Local)
 | 
			
		||||
			c.ErrorHandler(errors.New("stream error: " + p.Error.Local))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.router.route(c, val)
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"encoding/xml"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"gosrc.io/xmpp/stanza"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -15,19 +16,7 @@ import (
 | 
			
		||||
// Tests are ran in parallel, so each test creating a server must use a different port so we do not get any
 | 
			
		||||
// conflict. Using iota for this should do the trick.
 | 
			
		||||
const (
 | 
			
		||||
	testComponentDomain  = "localhost"
 | 
			
		||||
	defaultServerName    = "testServer"
 | 
			
		||||
	defaultStreamID      = "91bd0bba-012f-4d92-bb17-5fc41e6fe545"
 | 
			
		||||
	defaultComponentName = "Test Component"
 | 
			
		||||
 | 
			
		||||
	// Default port is not standard XMPP port to avoid interfering
 | 
			
		||||
	// with local running XMPP server
 | 
			
		||||
	testHandshakePort = iota + 15222
 | 
			
		||||
	testDecoderPort
 | 
			
		||||
	testSendIqPort
 | 
			
		||||
	testSendRawPort
 | 
			
		||||
	testDisconnectPort
 | 
			
		||||
	testSManDisconnectPort
 | 
			
		||||
	defaultChannelTimeout = 5 * time.Second
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHandshake(t *testing.T) {
 | 
			
		||||
@@ -48,16 +37,14 @@ func TestHandshake(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
// Tests connection process with a handshake exchange
 | 
			
		||||
// Tests multiple session IDs. All connections should generate a unique stream ID
 | 
			
		||||
func TestGenerateHandshake(t *testing.T) {
 | 
			
		||||
func TestGenerateHandshakeId(t *testing.T) {
 | 
			
		||||
	// Using this array with a channel to make a queue of values to test
 | 
			
		||||
	// These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate
 | 
			
		||||
	// some handshake value
 | 
			
		||||
	var uuidsArray = [5]string{
 | 
			
		||||
		"cc9b3249-9582-4780-825f-4311b42f9b0e",
 | 
			
		||||
		"bba8be3c-d98e-4e26-b9bb-9ed34578a503",
 | 
			
		||||
		"dae72822-80e8-496b-b763-ab685f53a188",
 | 
			
		||||
		"a45d6c06-de49-4bb0-935b-1a2201b71028",
 | 
			
		||||
		"7dc6924f-0eca-4237-9898-18654b8d891e",
 | 
			
		||||
	var uuidsArray = [5]string{}
 | 
			
		||||
	for i := 1; i < len(uuidsArray); i++ {
 | 
			
		||||
		id, _ := uuid.NewRandom()
 | 
			
		||||
		uuidsArray[i] = id.String()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Channel to pass stream IDs as a queue
 | 
			
		||||
@@ -95,7 +82,7 @@ func TestGenerateHandshake(t *testing.T) {
 | 
			
		||||
		Type:     "service",
 | 
			
		||||
	}
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	c, err := NewComponent(opts, router)
 | 
			
		||||
	c, err := NewComponent(opts, router, componentDefaultErrorHandler)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -126,7 +113,7 @@ func TestStreamManager(t *testing.T) {
 | 
			
		||||
// The decoder is expected to be built after a valid connection
 | 
			
		||||
// Based on the xmpp_component example.
 | 
			
		||||
func TestDecoder(t *testing.T) {
 | 
			
		||||
	c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID)
 | 
			
		||||
	c, _ := mockComponentConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID)
 | 
			
		||||
	if c.transport.GetDecoder() == nil {
 | 
			
		||||
		t.Errorf("Failed to initialize decoder. Decoder is nil.")
 | 
			
		||||
	}
 | 
			
		||||
@@ -134,39 +121,103 @@ func TestDecoder(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
// Tests sending an IQ to the server, and getting the response
 | 
			
		||||
func TestSendIq(t *testing.T) {
 | 
			
		||||
	done := make(chan struct{})
 | 
			
		||||
	h := func(t *testing.T, c net.Conn) {
 | 
			
		||||
		handlerForComponentIQSend(t, c)
 | 
			
		||||
		done <- struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//Connecting to a mock server, initialized with given port and handler function
 | 
			
		||||
	c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend)
 | 
			
		||||
	c, m := mockComponentConnection(t, testSendIqPort, h)
 | 
			
		||||
 | 
			
		||||
	ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
 | 
			
		||||
	iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
 | 
			
		||||
	disco := iqReq.DiscoInfo()
 | 
			
		||||
	iqReq.Payload = disco
 | 
			
		||||
 | 
			
		||||
	// Handle a possible error
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	errorHandler := func(err error) {
 | 
			
		||||
		errChan <- err
 | 
			
		||||
	}
 | 
			
		||||
	c.ErrorHandler = errorHandler
 | 
			
		||||
 | 
			
		||||
	var res chan stanza.IQ
 | 
			
		||||
	res, _ = c.SendIQ(ctx, iqReq)
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-res:
 | 
			
		||||
	case <-time.After(100 * time.Millisecond):
 | 
			
		||||
	case err := <-errChan:
 | 
			
		||||
		t.Errorf(err.Error())
 | 
			
		||||
	case <-time.After(defaultChannelTimeout):
 | 
			
		||||
		t.Errorf("Failed to receive response, to sent IQ, from mock server")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	m.Stop()
 | 
			
		||||
	select {
 | 
			
		||||
	case <-done:
 | 
			
		||||
		m.Stop()
 | 
			
		||||
	case <-time.After(defaultChannelTimeout):
 | 
			
		||||
		t.Errorf("The mock server failed to finish its job !")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Checking that error handling is done properly client side when an invalid IQ is sent and the server responds in kind.
 | 
			
		||||
func TestSendIqFail(t *testing.T) {
 | 
			
		||||
	done := make(chan struct{})
 | 
			
		||||
	h := func(t *testing.T, c net.Conn) {
 | 
			
		||||
		handlerForComponentIQSend(t, c)
 | 
			
		||||
		done <- struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	//Connecting to a mock server, initialized with given port and handler function
 | 
			
		||||
	c, m := mockComponentConnection(t, testSendIqFailPort, h)
 | 
			
		||||
 | 
			
		||||
	ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
 | 
			
		||||
	iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
 | 
			
		||||
 | 
			
		||||
	// Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified
 | 
			
		||||
	// so we need to overwrite it.
 | 
			
		||||
	iqReq.Id = ""
 | 
			
		||||
	disco := iqReq.DiscoInfo()
 | 
			
		||||
	iqReq.Payload = disco
 | 
			
		||||
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	errorHandler := func(err error) {
 | 
			
		||||
		errChan <- err
 | 
			
		||||
	}
 | 
			
		||||
	c.ErrorHandler = errorHandler
 | 
			
		||||
 | 
			
		||||
	var res chan stanza.IQ
 | 
			
		||||
	res, _ = c.SendIQ(ctx, iqReq)
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case r := <-res: // Do we get an IQ response from the server ?
 | 
			
		||||
		t.Errorf("We should not be getting an IQ response here : this should fail !")
 | 
			
		||||
		fmt.Println(r)
 | 
			
		||||
	case <-errChan: // Do we get a stream error from the server ?
 | 
			
		||||
		// If we get an error from the server, the test passes.
 | 
			
		||||
	case <-time.After(defaultChannelTimeout): // Timeout ?
 | 
			
		||||
		t.Errorf("Failed to receive response, to sent IQ, from mock server")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-done:
 | 
			
		||||
		m.Stop()
 | 
			
		||||
	case <-time.After(defaultChannelTimeout):
 | 
			
		||||
		t.Errorf("The mock server failed to finish its job !")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests sending raw xml to the mock server.
 | 
			
		||||
// TODO : check the server response client side ?
 | 
			
		||||
// Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err.
 | 
			
		||||
// In this test, we use IQs
 | 
			
		||||
func TestSendRaw(t *testing.T) {
 | 
			
		||||
	// Error channel for the handler
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	done := make(chan struct{})
 | 
			
		||||
	// Handler for the mock server
 | 
			
		||||
	h := func(t *testing.T, c net.Conn) {
 | 
			
		||||
		// Completes the connection by exchanging handshakes
 | 
			
		||||
		handlerForComponentHandshakeDefaultID(t, c)
 | 
			
		||||
		receiveRawIq(t, c, errChan)
 | 
			
		||||
		return
 | 
			
		||||
		receiveIq(c, xml.NewDecoder(c))
 | 
			
		||||
		done <- struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	type testCase struct {
 | 
			
		||||
@@ -185,12 +236,19 @@ func TestSendRaw(t *testing.T) {
 | 
			
		||||
		shouldErr: true,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// A handler for the component.
 | 
			
		||||
	// In the failing test, the server returns a stream error, which triggers this handler, component side.
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	errHandler := func(err error) {
 | 
			
		||||
		errChan <- err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Tests for all the IQs
 | 
			
		||||
	for name, tcase := range testRequests {
 | 
			
		||||
		t.Run(name, func(st *testing.T) {
 | 
			
		||||
			//Connecting to a mock server, initialized with given port and handler function
 | 
			
		||||
			c, m := mockConnection(t, testSendRawPort, h)
 | 
			
		||||
 | 
			
		||||
			c, m := mockComponentConnection(t, testSendRawPort, h)
 | 
			
		||||
			c.ErrorHandler = errHandler
 | 
			
		||||
			// Sending raw xml from test case
 | 
			
		||||
			err := c.SendRaw(tcase.req)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -198,21 +256,29 @@ func TestSendRaw(t *testing.T) {
 | 
			
		||||
			}
 | 
			
		||||
			// Just wait a little so the message has time to arrive
 | 
			
		||||
			select {
 | 
			
		||||
			case <-time.After(100 * time.Millisecond):
 | 
			
		||||
			// We don't use the default "long" timeout here because waiting it out means passing the test.
 | 
			
		||||
			case <-time.After(200 * time.Millisecond):
 | 
			
		||||
			case err = <-errChan:
 | 
			
		||||
				if err == nil && tcase.shouldErr {
 | 
			
		||||
					t.Errorf("Failed to get closing stream err")
 | 
			
		||||
				} else if err != nil && !tcase.shouldErr {
 | 
			
		||||
					t.Errorf("This test is not supposed to err ! => %s", err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			c.transport.Close()
 | 
			
		||||
			m.Stop()
 | 
			
		||||
			select {
 | 
			
		||||
			case <-done:
 | 
			
		||||
				m.Stop()
 | 
			
		||||
			case <-time.After(defaultChannelTimeout):
 | 
			
		||||
				t.Errorf("The mock server failed to finish its job !")
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests the Disconnect method for Components
 | 
			
		||||
func TestDisconnect(t *testing.T) {
 | 
			
		||||
	c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID)
 | 
			
		||||
	c, m := mockComponentConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID)
 | 
			
		||||
	err := c.transport.Ping()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Could not ping but not disconnected yet")
 | 
			
		||||
@@ -257,14 +323,97 @@ func TestStreamManagerDisconnect(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
//=============================================================================
 | 
			
		||||
// Basic XMPP Server Mock Handlers.
 | 
			
		||||
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
 | 
			
		||||
// Used in the mock server as a Handler
 | 
			
		||||
func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) {
 | 
			
		||||
	decoder := xml.NewDecoder(c)
 | 
			
		||||
	checkOpenStreamHandshakeDefaultID(t, c, decoder)
 | 
			
		||||
	readHandshakeComponent(t, decoder)
 | 
			
		||||
	fmt.Fprintln(c, "<handshake/>") // That's all the server needs to return (see xep-0114)
 | 
			
		||||
	return
 | 
			
		||||
 | 
			
		||||
//===============================
 | 
			
		||||
// Init mock server and connection
 | 
			
		||||
// Creating a mock server and connecting a Component to it. Initialized with given port and handler function
 | 
			
		||||
// The Component and mock are both returned
 | 
			
		||||
func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) {
 | 
			
		||||
	// Init mock server
 | 
			
		||||
	testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port)
 | 
			
		||||
	mock := ServerMock{}
 | 
			
		||||
	mock.Start(t, testComponentAddress, handler)
 | 
			
		||||
 | 
			
		||||
	//==================================
 | 
			
		||||
	// Create Component to connect to it
 | 
			
		||||
	c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
 | 
			
		||||
 | 
			
		||||
	//========================================
 | 
			
		||||
	// Connect the new Component to the server
 | 
			
		||||
	err := c.Connect()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return c, &mock
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component {
 | 
			
		||||
	opts := ComponentOptions{
 | 
			
		||||
		TransportConfiguration: TransportConfiguration{
 | 
			
		||||
			Address: mockServerAddr,
 | 
			
		||||
			Domain:  "localhost",
 | 
			
		||||
		},
 | 
			
		||||
		Domain:   testComponentDomain,
 | 
			
		||||
		Secret:   "mypass",
 | 
			
		||||
		Name:     name,
 | 
			
		||||
		Category: "gateway",
 | 
			
		||||
		Type:     "service",
 | 
			
		||||
	}
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	c, err := NewComponent(opts, router, componentDefaultErrorHandler)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
	c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This really should not be used as is.
 | 
			
		||||
// It's just meant to be a placeholder when error handling is not needed at this level
 | 
			
		||||
func componentDefaultErrorHandler(err error) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Sends IQ response to Component request.
 | 
			
		||||
// No parsing of the request here. We just check that it's valid, and send the default response.
 | 
			
		||||
func handlerForComponentIQSend(t *testing.T, c net.Conn) {
 | 
			
		||||
	// Completes the connection by exchanging handshakes
 | 
			
		||||
	handlerForComponentHandshakeDefaultID(t, c)
 | 
			
		||||
	respondToIQ(t, c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Used for ID and handshake related tests
 | 
			
		||||
func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) {
 | 
			
		||||
	c.SetDeadline(time.Now().Add(defaultTimeout))
 | 
			
		||||
	defer c.SetDeadline(time.Time{})
 | 
			
		||||
 | 
			
		||||
	for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion.
 | 
			
		||||
		token, err := decoder.Token()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("cannot read next token: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		switch elem := token.(type) {
 | 
			
		||||
		// Wait for first startElement
 | 
			
		||||
		case xml.StartElement:
 | 
			
		||||
			if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" {
 | 
			
		||||
				err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil {
 | 
			
		||||
				t.Errorf("cannot write server stream open: %s", err)
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
	checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
 | 
			
		||||
@@ -303,152 +452,12 @@ func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
	checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Used for ID and handshake related tests
 | 
			
		||||
func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) {
 | 
			
		||||
	c.SetDeadline(time.Now().Add(defaultTimeout))
 | 
			
		||||
	defer c.SetDeadline(time.Time{})
 | 
			
		||||
 | 
			
		||||
	for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion.
 | 
			
		||||
		token, err := decoder.Token()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("cannot read next token: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		switch elem := token.(type) {
 | 
			
		||||
		// Wait for first startElement
 | 
			
		||||
		case xml.StartElement:
 | 
			
		||||
			if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" {
 | 
			
		||||
				err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil {
 | 
			
		||||
				t.Errorf("cannot write server stream open: %s", err)
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//=============================================================================
 | 
			
		||||
// Sends IQ response to Component request.
 | 
			
		||||
// No parsing of the request here. We just check that it's valid, and send the default response.
 | 
			
		||||
func handlerForComponentIQSend(t *testing.T, c net.Conn) {
 | 
			
		||||
	// Completes the connection by exchanging handshakes
 | 
			
		||||
	handlerForComponentHandshakeDefaultID(t, c)
 | 
			
		||||
 | 
			
		||||
	// Decoder to parse the request
 | 
			
		||||
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
 | 
			
		||||
// Used in the mock server as a Handler
 | 
			
		||||
func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) {
 | 
			
		||||
	decoder := xml.NewDecoder(c)
 | 
			
		||||
 | 
			
		||||
	iqReq, err := receiveIq(t, c, decoder)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Error receiving the IQ stanza : %v", err)
 | 
			
		||||
	} else if !iqReq.IsValid() {
 | 
			
		||||
		t.Errorf("server received an IQ stanza : %v", iqReq)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Crafting response
 | 
			
		||||
	iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"})
 | 
			
		||||
	disco := iqResp.DiscoInfo()
 | 
			
		||||
	disco.AddFeatures("vcard-temp",
 | 
			
		||||
		`http://jabber.org/protocol/address`)
 | 
			
		||||
 | 
			
		||||
	disco.AddIdentity("Multicast", "service", "multicast")
 | 
			
		||||
	iqResp.Payload = disco
 | 
			
		||||
 | 
			
		||||
	// Sending response to the Component
 | 
			
		||||
	mResp, err := xml.Marshal(iqResp)
 | 
			
		||||
	_, err = fmt.Fprintln(c, string(mResp))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Could not send response stanza : %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	checkOpenStreamHandshakeDefaultID(t, c, decoder)
 | 
			
		||||
	readHandshakeComponent(t, decoder)
 | 
			
		||||
	fmt.Fprintln(c, "<handshake/>") // That's all the server needs to return (see xep-0114)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reads next request coming from the Component. Expecting it to be an IQ request
 | 
			
		||||
func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) {
 | 
			
		||||
	c.SetDeadline(time.Now().Add(defaultTimeout))
 | 
			
		||||
	defer c.SetDeadline(time.Time{})
 | 
			
		||||
	var iqStz stanza.IQ
 | 
			
		||||
	err := decoder.Decode(&iqStz)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("cannot read the received IQ stanza: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	if !iqStz.IsValid() {
 | 
			
		||||
		t.Errorf("received IQ stanza is invalid : %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	return iqStz, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) {
 | 
			
		||||
	c.SetDeadline(time.Now().Add(defaultTimeout))
 | 
			
		||||
	defer c.SetDeadline(time.Time{})
 | 
			
		||||
	decoder := xml.NewDecoder(c)
 | 
			
		||||
	var iq stanza.IQ
 | 
			
		||||
	err := decoder.Decode(&iq)
 | 
			
		||||
	if err != nil || !iq.IsValid() {
 | 
			
		||||
		s := stanza.StreamError{
 | 
			
		||||
			XMLName: xml.Name{Local: "stream:error"},
 | 
			
		||||
			Error:   xml.Name{Local: "xml-not-well-formed"},
 | 
			
		||||
			Text:    `XML was not well-formed`,
 | 
			
		||||
		}
 | 
			
		||||
		raw, _ := xml.Marshal(s)
 | 
			
		||||
		fmt.Fprintln(c, string(raw))
 | 
			
		||||
		fmt.Fprintln(c, `</stream:stream>`) // TODO : check this client side
 | 
			
		||||
		errChan <- fmt.Errorf("invalid xml")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	errChan <- nil
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===============================
 | 
			
		||||
// Init mock server and connection
 | 
			
		||||
// Creating a mock server and connecting a Component to it. Initialized with given port and handler function
 | 
			
		||||
// The Component and mock are both returned
 | 
			
		||||
func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) {
 | 
			
		||||
	// Init mock server
 | 
			
		||||
	testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port)
 | 
			
		||||
	mock := ServerMock{}
 | 
			
		||||
	mock.Start(t, testComponentAddress, handler)
 | 
			
		||||
 | 
			
		||||
	//==================================
 | 
			
		||||
	// Create Component to connect to it
 | 
			
		||||
	c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
 | 
			
		||||
 | 
			
		||||
	//========================================
 | 
			
		||||
	// Connect the new Component to the server
 | 
			
		||||
	err := c.Connect()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return c, &mock
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component {
 | 
			
		||||
	opts := ComponentOptions{
 | 
			
		||||
		TransportConfiguration: TransportConfiguration{
 | 
			
		||||
			Address: mockServerAddr,
 | 
			
		||||
			Domain:  "localhost",
 | 
			
		||||
		},
 | 
			
		||||
		Domain:   testComponentDomain,
 | 
			
		||||
		Secret:   "mypass",
 | 
			
		||||
		Name:     name,
 | 
			
		||||
		Category: "gateway",
 | 
			
		||||
		Type:     "service",
 | 
			
		||||
	}
 | 
			
		||||
	router := NewRouter()
 | 
			
		||||
	c, err := NewComponent(opts, router)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
	c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("%+v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,12 +1,42 @@
 | 
			
		||||
package xmpp
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/xml"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gosrc.io/xmpp/stanza"
 | 
			
		||||
	"net"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
//=============================================================================
 | 
			
		||||
// TCP Server Mock
 | 
			
		||||
const (
 | 
			
		||||
	defaultTimeout       = 2 * time.Second
 | 
			
		||||
	testComponentDomain  = "localhost"
 | 
			
		||||
	defaultServerName    = "testServer"
 | 
			
		||||
	defaultStreamID      = "91bd0bba-012f-4d92-bb17-5fc41e6fe545"
 | 
			
		||||
	defaultComponentName = "Test Component"
 | 
			
		||||
	serverStreamOpen     = "<?xml version='1.0'?><stream:stream to='%s' id='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
 | 
			
		||||
 | 
			
		||||
	// Default port is not standard XMPP port to avoid interfering
 | 
			
		||||
	// with local running XMPP server
 | 
			
		||||
 | 
			
		||||
	// Component tests
 | 
			
		||||
	testHandshakePort = iota + 15222
 | 
			
		||||
	testDecoderPort
 | 
			
		||||
	testSendIqPort
 | 
			
		||||
	testSendIqFailPort
 | 
			
		||||
	testSendRawPort
 | 
			
		||||
	testDisconnectPort
 | 
			
		||||
	testSManDisconnectPort
 | 
			
		||||
 | 
			
		||||
	// Client tests
 | 
			
		||||
	testClientBasePort
 | 
			
		||||
	testClientRawPort
 | 
			
		||||
	testClientIqPort
 | 
			
		||||
	testClientIqFailPort
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ClientHandler is passed by the test client to provide custom behaviour to
 | 
			
		||||
// the TCP server mock. This allows customizing the server behaviour to allow
 | 
			
		||||
@@ -81,3 +111,180 @@ func (mock *ServerMock) loop() {
 | 
			
		||||
		go mock.handler(mock.t, conn)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//======================================================================================================================
 | 
			
		||||
// A few functions commonly used for tests. Trying to avoid duplicates in client and component test files.
 | 
			
		||||
//======================================================================================================================
 | 
			
		||||
 | 
			
		||||
func respondToIQ(t *testing.T, c net.Conn) {
 | 
			
		||||
	// Decoder to parse the request
 | 
			
		||||
	decoder := xml.NewDecoder(c)
 | 
			
		||||
 | 
			
		||||
	iqReq, err := receiveIq(c, decoder)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("failed to receive IQ : %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !iqReq.IsValid() {
 | 
			
		||||
		mockIQError(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Crafting response
 | 
			
		||||
	iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"})
 | 
			
		||||
	disco := iqResp.DiscoInfo()
 | 
			
		||||
	disco.AddFeatures("vcard-temp",
 | 
			
		||||
		`http://jabber.org/protocol/address`)
 | 
			
		||||
 | 
			
		||||
	disco.AddIdentity("Multicast", "service", "multicast")
 | 
			
		||||
	iqResp.Payload = disco
 | 
			
		||||
 | 
			
		||||
	// Sending response to the Component
 | 
			
		||||
	mResp, err := xml.Marshal(iqResp)
 | 
			
		||||
	_, err = fmt.Fprintln(c, string(mResp))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Could not send response stanza : %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// When a presence stanza is automatically sent (right now it's the case in the client), we may want to discard it
 | 
			
		||||
// and test further stanzas.
 | 
			
		||||
func discardPresence(t *testing.T, c net.Conn) {
 | 
			
		||||
	decoder := xml.NewDecoder(c)
 | 
			
		||||
	c.SetDeadline(time.Now().Add(defaultTimeout))
 | 
			
		||||
	defer c.SetDeadline(time.Time{})
 | 
			
		||||
	var presenceStz stanza.Presence
 | 
			
		||||
	err := decoder.Decode(&presenceStz)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Expected presence but this happened : %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reads next request coming from the Component. Expecting it to be an IQ request
 | 
			
		||||
func receiveIq(c net.Conn, decoder *xml.Decoder) (*stanza.IQ, error) {
 | 
			
		||||
	c.SetDeadline(time.Now().Add(defaultTimeout))
 | 
			
		||||
	defer c.SetDeadline(time.Time{})
 | 
			
		||||
	var iqStz stanza.IQ
 | 
			
		||||
	err := decoder.Decode(&iqStz)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &iqStz, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Should be used in server handlers when an IQ sent by a client or component is invalid.
 | 
			
		||||
// This responds as expected from a "real" server, aside from the error message.
 | 
			
		||||
func mockIQError(c net.Conn) {
 | 
			
		||||
	s := stanza.StreamError{
 | 
			
		||||
		XMLName: xml.Name{Local: "stream:error"},
 | 
			
		||||
		Error:   xml.Name{Local: "xml-not-well-formed"},
 | 
			
		||||
		Text:    `XML was not well-formed`,
 | 
			
		||||
	}
 | 
			
		||||
	raw, _ := xml.Marshal(s)
 | 
			
		||||
	fmt.Fprintln(c, string(raw))
 | 
			
		||||
	fmt.Fprintln(c, `</stream:stream>`)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) {
 | 
			
		||||
	// This is a basic server, supporting only 1 stream feature: SASL Plain Auth
 | 
			
		||||
	features := `<stream:features>
 | 
			
		||||
  <mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl">
 | 
			
		||||
    <mechanism>PLAIN</mechanism>
 | 
			
		||||
  </mechanisms>
 | 
			
		||||
</stream:features>`
 | 
			
		||||
	if _, err := fmt.Fprintln(c, features); err != nil {
 | 
			
		||||
		t.Errorf("cannot send stream feature: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO return err in case of error reading the auth params
 | 
			
		||||
func readAuth(t *testing.T, decoder *xml.Decoder) string {
 | 
			
		||||
	se, err := stanza.NextStart(decoder)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("cannot read auth: %s", err)
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var nv interface{}
 | 
			
		||||
	nv = &stanza.SASLAuth{}
 | 
			
		||||
	// Decode element into pointer storage
 | 
			
		||||
	if err = decoder.DecodeElement(nv, &se); err != nil {
 | 
			
		||||
		t.Errorf("cannot decode auth: %s", err)
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch v := nv.(type) {
 | 
			
		||||
	case *stanza.SASLAuth:
 | 
			
		||||
		return v.Value
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) {
 | 
			
		||||
	// This is a basic server, supporting only 1 stream feature after auth: resource binding
 | 
			
		||||
	features := `<stream:features>
 | 
			
		||||
  <bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>
 | 
			
		||||
</stream:features>`
 | 
			
		||||
	if _, err := fmt.Fprintln(c, features); err != nil {
 | 
			
		||||
		t.Errorf("cannot send stream feature: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) {
 | 
			
		||||
	// This is a basic server, supporting only 2 features after auth: resource & session binding
 | 
			
		||||
	features := `<stream:features>
 | 
			
		||||
  <bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>
 | 
			
		||||
  <session xmlns='urn:ietf:params:xml:ns:xmpp-session'/>
 | 
			
		||||
</stream:features>`
 | 
			
		||||
	if _, err := fmt.Fprintln(c, features); err != nil {
 | 
			
		||||
		t.Errorf("cannot send stream feature: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
	se, err := stanza.NextStart(decoder)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("cannot read bind: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	iq := &stanza.IQ{}
 | 
			
		||||
	// Decode element into pointer storage
 | 
			
		||||
	if err = decoder.DecodeElement(&iq, &se); err != nil {
 | 
			
		||||
		t.Errorf("cannot decode bind iq: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO Check all elements
 | 
			
		||||
	switch iq.Payload.(type) {
 | 
			
		||||
	case *stanza.Bind:
 | 
			
		||||
		result := `<iq id='%s' type='result'>
 | 
			
		||||
  <bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'>
 | 
			
		||||
  	<jid>%s</jid>
 | 
			
		||||
  </bind>
 | 
			
		||||
</iq>`
 | 
			
		||||
		fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func session(t *testing.T, c net.Conn, decoder *xml.Decoder) {
 | 
			
		||||
	se, err := stanza.NextStart(decoder)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("cannot read session: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	iq := &stanza.IQ{}
 | 
			
		||||
	// Decode element into pointer storage
 | 
			
		||||
	if err = decoder.DecodeElement(&iq, &se); err != nil {
 | 
			
		||||
		t.Errorf("cannot decode session iq: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch iq.Payload.(type) {
 | 
			
		||||
	case *stanza.StreamSession:
 | 
			
		||||
		result := `<iq id='%s' type='result'/>`
 | 
			
		||||
		fmt.Fprintf(c, result, iq.Id)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user