From 3f4867294631abd45a8c4137c142af3e30bb9893 Mon Sep 17 00:00:00 2001
From: Mickael Remond <mremond@process-one.net>
Date: Wed, 31 Jul 2019 18:47:30 +0200
Subject: [PATCH] Add initial support for stream management

For now it support enabling SM, replying to ack requests from server,
and trying resuming the session with existing Stream Management state.
---
 client.go                   |  60 +++++++++++++++---
 component.go                |   4 ++
 session.go                  |  75 +++++++++++++++++++++-
 stanza/parser.go            |   4 +-
 stanza/stream_management.go | 121 ++++++++++++++++++++++++++++++++++++
 stream_manager.go           |  12 +++-
 6 files changed, 261 insertions(+), 15 deletions(-)
 create mode 100644 stanza/stream_management.go

diff --git a/client.go b/client.go
index 7bb7a80..7e40cb4 100644
--- a/client.go
+++ b/client.go
@@ -31,6 +31,18 @@ type Event struct {
 	State       ConnState
 	Description string
 	StreamError string
+	SMState     SMState
+}
+
+// SMState holds Stream Management information regarding the session that can be
+// used to resume session after disconnect
+type SMState struct {
+	// Stream Management ID
+	Id string
+	// Inbound stanza count
+	Inbound uint
+	// TODO Store location for IP affinity
+	// TODO Store max and timestamp, to check if we should retry resumption or not
 }
 
 // EventHandler is use to pass events about state of the connection to
@@ -52,6 +64,13 @@ func (em EventManager) updateState(state ConnState) {
 	}
 }
 
+func (em EventManager) disconnected(state SMState) {
+	em.CurrentState = StateDisconnected
+	if em.Handler != nil {
+		em.Handler(Event{State: em.CurrentState, SMState: state})
+	}
+}
+
 func (em EventManager) streamError(error, desc string) {
 	em.CurrentState = StateStreamError
 	if em.Handler != nil {
@@ -128,7 +147,15 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
 }
 
 // Connect triggers actual TCP connection, based on previously defined parameters.
+// Connect simply triggers resumption, with an empty session state.
 func (c *Client) Connect() error {
+	var state SMState
+	return c.Resume(state)
+}
+
+// Resume attempts resuming  a Stream Managed session, based on the provided stream management
+// state.
+func (c *Client) Resume(state SMState) error {
 	var err error
 
 	c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second)
@@ -138,23 +165,24 @@ func (c *Client) Connect() error {
 	c.updateState(StateConnected)
 
 	// Client is ok, we now open XMPP session
-	if c.conn, c.Session, err = NewSession(c.conn, c.config); err != nil {
+	if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil {
 		return err
 	}
 	c.updateState(StateSessionEstablished)
 
+	// Start the keepalive go routine
+	keepaliveQuit := make(chan struct{})
+	go keepalive(c.conn, keepaliveQuit)
+	// Start the receiver go routine
+	state = c.Session.SMState
+	go c.recv(state, keepaliveQuit)
+
 	// We're connected and can now receive and send messages.
 	//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
 	// TODO: Do we always want to send initial presence automatically ?
 	// Do we need an option to avoid that or do we rely on client to send the presence itself ?
 	fmt.Fprintf(c.Session.streamLogger, "<presence/>")
 
-	// Start the keepalive go routine
-	keepaliveQuit := make(chan struct{})
-	go keepalive(c.conn, keepaliveQuit)
-	// Start the receiver go routine
-	go c.recv(keepaliveQuit)
-
 	return err
 }
 
@@ -206,12 +234,12 @@ func (c *Client) sendWithLogger(packet string) error {
 // Go routines
 
 // Loop: Receive data from server
-func (c *Client) recv(keepaliveQuit chan<- struct{}) (err error) {
+func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) {
 	for {
 		val, err := stanza.NextPacket(c.Session.decoder)
 		if err != nil {
 			close(keepaliveQuit)
-			c.updateState(StateDisconnected)
+			c.disconnected(state)
 			return err
 		}
 
@@ -222,6 +250,17 @@ func (c *Client) recv(keepaliveQuit chan<- struct{}) (err error) {
 			close(keepaliveQuit)
 			c.streamError(packet.Error.Local, packet.Text)
 			return errors.New("stream error: " + packet.Error.Local)
+		// Process Stream management nonzas
+		case stanza.SMRequest:
+			fmt.Println("MREMOND: inbound: ", state.Inbound)
+			answer := stanza.SMAnswer{XMLName: xml.Name{
+				Space: stanza.NSStreamManagement,
+				Local: "a",
+			}, H: state.Inbound}
+			c.Send(answer)
+		default:
+			fmt.Println(packet)
+			state.Inbound++
 		}
 
 		c.router.route(c, val)
@@ -243,6 +282,9 @@ func keepalive(conn net.Conn, quit <-chan struct{}) {
 				_ = conn.Close()
 				return
 			}
+		case <-time.After(3 * time.Second):
+			_ = conn.Close()
+			return
 		case <-quit:
 			ticker.Stop()
 			return
diff --git a/component.go b/component.go
index 0176371..af424a2 100644
--- a/component.go
+++ b/component.go
@@ -108,6 +108,10 @@ func (c *Component) Connect() error {
 	}
 }
 
+func (c *Component) Resume() error {
+	return errors.New("components do not support stream management")
+}
+
 func (c *Component) Disconnect() {
 	_ = c.SendRaw("</stream:stream>")
 	// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
diff --git a/session.go b/session.go
index 1092569..8253ab6 100644
--- a/session.go
+++ b/session.go
@@ -17,6 +17,7 @@ type Session struct {
 	// Session info
 	BindJid      string // Jabber ID as provided by XMPP server
 	StreamId     string
+	SMState      SMState
 	Features     stanza.StreamFeatures
 	TlsEnabled   bool
 	lastPacketId int
@@ -29,8 +30,9 @@ type Session struct {
 	err error
 }
 
-func NewSession(conn net.Conn, o Config) (net.Conn, *Session, error) {
+func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) {
 	s := new(Session)
+	s.SMState = state
 	s.init(conn, o)
 
 	// starttls
@@ -54,10 +56,18 @@ func NewSession(conn net.Conn, o Config) (net.Conn, *Session, error) {
 	s.auth(o)
 	s.reset(tlsConn, tlsConn, o)
 
-	// bind resource and 'start' XMPP session
+	// attempt resumption
+	if s.resume(o) {
+		return tlsConn, s, s.err
+	}
+
+	// otherwise, bind resource and 'start' XMPP session
 	s.bind(o)
 	s.rfc3921Session(o)
 
+	// Enable stream management if supported
+	s.EnableStreamManagement(o)
+
 	return tlsConn, s, s.err
 }
 
@@ -161,6 +171,39 @@ func (s *Session) auth(o Config) {
 	s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Password)
 }
 
+// Attempt to resume session using stream management
+func (s *Session) resume(o Config) bool {
+	if !s.Features.DoesStreamManagement() {
+		return false
+	}
+	if s.SMState.Id == "" {
+		return false
+	}
+
+	fmt.Fprintf(s.streamLogger, "<resume xmlns='%s' h='%d' previd='%s'/>",
+		stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
+
+	var packet stanza.Packet
+	packet, s.err = stanza.NextPacket(s.decoder)
+	if s.err == nil {
+		switch p := packet.(type) {
+		case stanza.SMResumed:
+			if p.PrevId != s.SMState.Id {
+				s.err = errors.New("session resumption: mismatched id")
+				s.SMState = SMState{}
+				return false
+			}
+			return true
+		case stanza.SMFailed:
+			fmt.Println("MREMOND SM Failed")
+		default:
+			s.err = errors.New("unexpected reply to SM resume")
+		}
+	}
+	s.SMState = SMState{}
+	return false
+}
+
 func (s *Session) bind(o Config) {
 	if s.err != nil {
 		return
@@ -208,3 +251,31 @@ func (s *Session) rfc3921Session(o Config) {
 		}
 	}
 }
+
+// Enable stream management, with session resumption, if supported.
+func (s *Session) EnableStreamManagement(o Config) {
+	if s.err != nil {
+		return
+	}
+	if !s.Features.DoesStreamManagement() {
+		return
+	}
+
+	fmt.Fprintf(s.streamLogger, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
+
+	var packet stanza.Packet
+	packet, s.err = stanza.NextPacket(s.decoder)
+	if s.err == nil {
+		switch p := packet.(type) {
+		case stanza.SMEnabled:
+			s.SMState = SMState{Id: p.Id}
+		case stanza.SMFailed:
+			// TODO: Store error in SMState
+		default:
+			fmt.Println(p)
+			s.err = errors.New("unexpected reply to SM enable")
+		}
+	}
+
+	return
+}
diff --git a/stanza/parser.go b/stanza/parser.go
index c83e17e..cdd8b70 100644
--- a/stanza/parser.go
+++ b/stanza/parser.go
@@ -63,6 +63,8 @@ func NextPacket(p *xml.Decoder) (Packet, error) {
 		return decodeClient(p, se)
 	case NSComponent:
 		return decodeComponent(p, se)
+	case NSStreamManagement:
+		return sm.decode(p, se)
 	default:
 		return nil, errors.New("unknown namespace " +
 			se.Name.Space + " <" + se.Name.Local + "/>")
@@ -133,7 +135,7 @@ func decodeClient(p *xml.Decoder, se xml.StartElement) (Packet, error) {
 	}
 }
 
-// decodeClient decodes all known packets in the component namespace.
+// decodeComponent decodes all known packets in the component namespace.
 func decodeComponent(p *xml.Decoder, se xml.StartElement) (Packet, error) {
 	switch se.Name.Local {
 	case "handshake": // handshake is used to authenticate components
diff --git a/stanza/stream_management.go b/stanza/stream_management.go
new file mode 100644
index 0000000..ddbe9cd
--- /dev/null
+++ b/stanza/stream_management.go
@@ -0,0 +1,121 @@
+package stanza
+
+import (
+	"encoding/xml"
+	"errors"
+)
+
+const (
+	NSStreamManagement = "urn:xmpp:sm:3"
+)
+
+// Enabled as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#enable
+type SMEnabled struct {
+	XMLName  xml.Name `xml:"urn:xmpp:sm:3 enabled"`
+	Id       string   `xml:"id,attr,omitempty"`
+	Location string   `xml:"location,attr,omitempty"`
+	Resume   string   `xml:"resume,attr,omitempty"`
+	Max      uint     `xml:"max,attr,omitempty"`
+}
+
+func (SMEnabled) Name() string {
+	return "Stream Management: enabled"
+}
+
+// Request as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMRequest struct {
+	XMLName xml.Name `xml:"urn:xmpp:sm:3 r"`
+}
+
+func (SMRequest) Name() string {
+	return "Stream Management: request"
+}
+
+// Answer as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMAnswer struct {
+	XMLName xml.Name `xml:"urn:xmpp:sm:3 a"`
+	H       uint     `xml:"h,attr,omitempty"`
+}
+
+func (SMAnswer) Name() string {
+	return "Stream Management: answer"
+}
+
+// Resumed as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMResumed struct {
+	XMLName xml.Name `xml:"urn:xmpp:sm:3 resumed"`
+	PrevId  string   `xml:"previd,attr,omitempty"`
+	H       uint     `xml:"h,attr,omitempty"`
+}
+
+func (SMResumed) Name() string {
+	return "Stream Management: resumed"
+}
+
+// Failed as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMFailed struct {
+	XMLName xml.Name `xml:"urn:xmpp:sm:3 failed"`
+	// TODO: Handle decoding error cause (need custom parsing).
+}
+
+func (SMFailed) Name() string {
+	return "Stream Management: failed"
+}
+
+type smDecoder struct{}
+
+var sm smDecoder
+
+// decode decodes all known nonza in the stream management namespace.
+func (s smDecoder) decode(p *xml.Decoder, se xml.StartElement) (Packet, error) {
+	switch se.Name.Local {
+	case "enabled":
+		return s.decodeEnabled(p, se)
+	case "resumed":
+		return s.decodeResumed(p, se)
+	case "r":
+		return s.decodeRequest(p, se)
+	case "h":
+		return s.decodeAnswer(p, se)
+	case "failed":
+		return s.decodeFailed(p, se)
+	default:
+		return nil, errors.New("unexpected XMPP packet " +
+			se.Name.Space + " <" + se.Name.Local + "/>")
+	}
+}
+
+func (smDecoder) decodeEnabled(p *xml.Decoder, se xml.StartElement) (SMEnabled, error) {
+	var packet SMEnabled
+	err := p.DecodeElement(&packet, &se)
+	return packet, err
+}
+
+func (smDecoder) decodeResumed(p *xml.Decoder, se xml.StartElement) (SMResumed, error) {
+	var packet SMResumed
+	err := p.DecodeElement(&packet, &se)
+	return packet, err
+}
+
+func (smDecoder) decodeRequest(p *xml.Decoder, se xml.StartElement) (SMRequest, error) {
+	var packet SMRequest
+	err := p.DecodeElement(&packet, &se)
+	return packet, err
+}
+
+func (smDecoder) decodeAnswer(p *xml.Decoder, se xml.StartElement) (SMAnswer, error) {
+	var packet SMAnswer
+	err := p.DecodeElement(&packet, &se)
+	return packet, err
+}
+
+func (smDecoder) decodeFailed(p *xml.Decoder, se xml.StartElement) (SMFailed, error) {
+	var packet SMFailed
+	err := p.DecodeElement(&packet, &se)
+	return packet, err
+}
diff --git a/stream_manager.go b/stream_manager.go
index 1aaf164..b81a783 100644
--- a/stream_manager.go
+++ b/stream_manager.go
@@ -24,6 +24,7 @@ import (
 // set callback and trigger reconnection.
 type StreamClient interface {
 	Connect() error
+	Resume(state SMState) error
 	Send(packet stanza.Packet) error
 	SendRaw(packet string) error
 	Disconnect()
@@ -78,7 +79,7 @@ func (sm *StreamManager) Run() error {
 			sm.Metrics.setLoginTime()
 		case StateDisconnected:
 			// Reconnect on disconnection
-			sm.connect()
+			sm.resume(e.SMState)
 		case StateStreamError:
 			sm.client.Disconnect()
 			// Only try reconnecting if we have not been kicked by another session to avoid connection loop.
@@ -106,8 +107,13 @@ func (sm *StreamManager) Stop() {
 	sm.wg.Done()
 }
 
-// connect manages the reconnection loop and apply the define backoff to avoid overloading the server.
 func (sm *StreamManager) connect() error {
+	var state SMState
+	return sm.resume(state)
+}
+
+// resume manages the reconnection loop and apply the define backoff to avoid overloading the server.
+func (sm *StreamManager) resume(state SMState) error {
 	var backoff backoff // TODO: Group backoff calculation features with connection manager?
 
 	for {
@@ -115,7 +121,7 @@ func (sm *StreamManager) connect() error {
 		// TODO: Make it possible to define logger to log disconnect and reconnection attempts
 		sm.Metrics = initMetrics()
 
-		if err = sm.client.Connect(); err != nil {
+		if err = sm.client.Resume(state); err != nil {
 			var actualErr ConnError
 			if xerrors.As(err, &actualErr) {
 				if actualErr.Permanent {