forked from jshiffer/go-xmpp
		
	Add a go function to always read websockets
Websocket need to have a Reader running at all times in order to allow Ping to work (because a Reader is the only thing that will correctly handle control frames). To faciliate this a go function is introduced that will always read from the websocket until it is cancelled. Read data is passed to the transport via a channel.
This commit is contained in:
		 Wichert Akkerman
					Wichert Akkerman
				
			
				
					committed by
					
						 Mickaël Rémond
						Mickaël Rémond
					
				
			
			
				
	
			
			
			 Mickaël Rémond
						Mickaël Rémond
					
				
			
						parent
						
							92329b48e6
						
					
				
				
					commit
					ffadd331dd
				
			| @@ -6,7 +6,6 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| @@ -14,6 +13,8 @@ import ( | ||||
| 	"nhooyr.io/websocket" | ||||
| ) | ||||
|  | ||||
| const maxPacketSize = 32768 | ||||
|  | ||||
| const pingTimeout = time.Duration(5) * time.Second | ||||
|  | ||||
| var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server does not support the xmpp subprotocol") | ||||
| @@ -22,17 +23,23 @@ type WebsocketTransport struct { | ||||
| 	Config  TransportConfiguration | ||||
| 	decoder *xml.Decoder | ||||
| 	wsConn  *websocket.Conn | ||||
| 	netConn net.Conn | ||||
| 	queue   chan []byte | ||||
| 	logFile io.Writer | ||||
|  | ||||
| 	closeCtx  context.Context | ||||
| 	closeFunc context.CancelFunc | ||||
| } | ||||
|  | ||||
| func (t *WebsocketTransport) Connect() (string, error) { | ||||
| 	ctx := context.Background() | ||||
| 	t.queue = make(chan []byte, 256) | ||||
| 	t.closeCtx, t.closeFunc = context.WithCancel(context.Background()) | ||||
|  | ||||
| 	var ctx context.Context | ||||
| 	ctx = context.Background() | ||||
| 	if t.Config.ConnectTimeout > 0 { | ||||
| 		var cancel context.CancelFunc | ||||
| 		ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Config.ConnectTimeout)*time.Second) | ||||
| 		defer cancel() | ||||
| 		var cancelConnect context.CancelFunc | ||||
| 		ctx, cancelConnect = context.WithTimeout(t.closeCtx, time.Duration(t.Config.ConnectTimeout)*time.Second) | ||||
| 		defer cancelConnect() | ||||
| 	} | ||||
|  | ||||
| 	wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{ | ||||
| @@ -42,28 +49,30 @@ func (t *WebsocketTransport) Connect() (string, error) { | ||||
| 		return "", NewConnError(err, true) | ||||
| 	} | ||||
| 	if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" { | ||||
| 		_ = wsConn.Close(websocket.StatusBadGateway, "Could not negotiate XMPP subprotocol") | ||||
| 		t.cleanup(websocket.StatusBadGateway) | ||||
| 		return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true) | ||||
| 	} | ||||
|  | ||||
| 	wsConn.SetReadLimit(maxPacketSize) | ||||
| 	t.wsConn = wsConn | ||||
| 	t.netConn = websocket.NetConn(ctx, t.wsConn, websocket.MessageText) | ||||
| 	t.startReader() | ||||
|  | ||||
| 	handshake := fmt.Sprintf("<open xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" to=\"%s\" version=\"1.0\" />", t.Config.Domain) | ||||
| 	handshake := fmt.Sprintf(`<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain) | ||||
| 	if _, err = t.Write([]byte(handshake)); err != nil { | ||||
| 		_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") | ||||
| 		t.cleanup(websocket.StatusBadGateway) | ||||
| 		return "", NewConnError(err, false) | ||||
| 	} | ||||
|  | ||||
| 	handshakeResponse := make([]byte, 2048) | ||||
| 	if _, err = t.Read(handshakeResponse); err != nil { | ||||
| 		_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") | ||||
| 		t.cleanup(websocket.StatusBadGateway) | ||||
|  | ||||
| 		return "", NewConnError(err, false) | ||||
| 	} | ||||
|  | ||||
| 	var openResponse = stanza.WebsocketOpen{} | ||||
| 	if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil { | ||||
| 		_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") | ||||
| 		t.cleanup(websocket.StatusBadGateway) | ||||
| 		return "", NewConnError(err, false) | ||||
| 	} | ||||
|  | ||||
| @@ -73,6 +82,32 @@ func (t *WebsocketTransport) Connect() (string, error) { | ||||
| 	return openResponse.Id, nil | ||||
| } | ||||
|  | ||||
| // startReader runs a go function that keeps reading from the websocket. This | ||||
| // is required to allow Ping() to work: Ping requires a Reader to be running | ||||
| // to process incoming control frames. | ||||
| func (t WebsocketTransport) startReader() { | ||||
| 	go func() { | ||||
| 		buffer := make([]byte, maxPacketSize) | ||||
| 		for { | ||||
| 			_, reader, err := t.wsConn.Reader(t.closeCtx) | ||||
| 			if err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			n, err := reader.Read(buffer) | ||||
| 			if err != nil && err != io.EOF { | ||||
| 				return | ||||
| 			} | ||||
| 			if n > 0 { | ||||
| 				// We need to make a copy, otherwise we will overwrite the slice content | ||||
| 				// on the next iteration of the for loop. | ||||
| 				tmp := make([]byte, len(buffer)) | ||||
| 				copy(tmp, buffer) | ||||
| 				t.queue <- tmp | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (t WebsocketTransport) StartTLS() error { | ||||
| 	return TLSNotSupported | ||||
| } | ||||
| @@ -90,31 +125,52 @@ func (t WebsocketTransport) IsSecure() bool { | ||||
| } | ||||
|  | ||||
| func (t WebsocketTransport) Ping() error { | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) | ||||
| 	ctx, cancel := context.WithTimeout(t.closeCtx, pingTimeout) | ||||
| 	defer cancel() | ||||
| 	return t.wsConn.Ping(ctx) | ||||
| } | ||||
|  | ||||
| func (t *WebsocketTransport) Read(p []byte) (n int, err error) { | ||||
| 	n, err = t.netConn.Read(p) | ||||
| 	if t.logFile != nil && n > 0 { | ||||
| 		_, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", p) | ||||
| func (t *WebsocketTransport) Read(p []byte) (int, error) { | ||||
| 	select { | ||||
| 	case <-t.closeCtx.Done(): | ||||
| 		return 0, t.closeCtx.Err() | ||||
| 	case data := <-t.queue: | ||||
| 		if t.logFile != nil && len(data) > 0 { | ||||
| 			_, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", data) | ||||
| 		} | ||||
| 		copy(p, data) | ||||
| 		return len(data), nil | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (t WebsocketTransport) Write(p []byte) (n int, err error) { | ||||
| func (t WebsocketTransport) Write(p []byte) (int, error) { | ||||
| 	if t.logFile != nil { | ||||
| 		_, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p) | ||||
| 	} | ||||
| 	return t.netConn.Write(p) | ||||
| 	return len(p), t.wsConn.Write(t.closeCtx, websocket.MessageText, p) | ||||
| } | ||||
|  | ||||
| func (t WebsocketTransport) Close() error { | ||||
| 	t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />")) | ||||
| 	return t.netConn.Close() | ||||
| 	return t.wsConn.Close(websocket.StatusGoingAway, "Done") | ||||
| } | ||||
|  | ||||
| func (t *WebsocketTransport) LogTraffic(logFile io.Writer) { | ||||
| 	t.logFile = logFile | ||||
| } | ||||
|  | ||||
| func (t *WebsocketTransport) cleanup(code websocket.StatusCode) { | ||||
| 	if t.queue != nil { | ||||
| 		close(t.queue) | ||||
| 		t.queue = nil | ||||
| 	} | ||||
| 	if t.wsConn != nil { | ||||
| 		t.wsConn.Close(websocket.StatusGoingAway, "Done") | ||||
| 		t.wsConn = nil | ||||
| 	} | ||||
| 	if t.closeFunc != nil { | ||||
| 		t.closeFunc() | ||||
| 		t.closeFunc = nil | ||||
| 		t.closeCtx = nil | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user