package xmpp

import (
	"encoding/xml"
	"reflect"
)

// ============================================================================
// Message Packet

// Message implements RFC 6120 - A.5 Client Namespace (a part)
type Message struct {
	XMLName xml.Name `xml:"message"`
	Attrs

	Subject    string         `xml:"subject,omitempty"`
	Body       string         `xml:"body,omitempty"`
	Thread     string         `xml:"thread,omitempty"`
	Error      Err            `xml:"error,omitempty"`
	Extensions []MsgExtension `xml:",omitempty"`
}

func (Message) Name() string {
	return "message"
}

func NewMessage(a Attrs) Message {
	return Message{
		XMLName: xml.Name{Local: "message"},
		Attrs:   a,
	}
}

// Get search and extracts a specific extension on a message.
// It receives a pointer to an MsgExtension. It will panic if the caller
// does not pass a pointer.
// It will return true if the passed extension is found and set the pointer
// to the extension passed as parameter to the found extension.
// It will return false if the extension is not found on the message.
//
// Example usage:
//   var oob xmpp.OOB
//   if ok := msg.Get(&oob); ok {
//     // oob extension has been found
//	 }
func (msg *Message) Get(ext MsgExtension) bool {
	target := reflect.ValueOf(ext)
	if target.Kind() != reflect.Ptr {
		panic("you must pass a pointer to the message Get method")
	}

	for _, e := range msg.Extensions {
		if reflect.TypeOf(e) == target.Type() {
			source := reflect.ValueOf(e)
			if source.Kind() != reflect.Ptr {
				source = source.Elem()
			}
			target.Elem().Set(source.Elem())
			return true
		}
	}
	return false
}

type messageDecoder struct{}

var message messageDecoder

func (messageDecoder) decode(p *xml.Decoder, se xml.StartElement) (Message, error) {
	var packet Message
	err := p.DecodeElement(&packet, &se)
	return packet, err
}

// XMPPFormat with all Extensions
func (msg *Message) XMPPFormat() string {
	out, err := xml.MarshalIndent(msg, "", "")
	if err != nil {
		return ""
	}
	return string(out)
}

// UnmarshalXML implements custom parsing for messages
func (msg *Message) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
	msg.XMLName = start.Name

	// Extract packet attributes
	for _, attr := range start.Attr {
		if attr.Name.Local == "id" {
			msg.Id = attr.Value
		}
		if attr.Name.Local == "type" {
			msg.Type = StanzaType(attr.Value)
		}
		if attr.Name.Local == "to" {
			msg.To = attr.Value
		}
		if attr.Name.Local == "from" {
			msg.From = attr.Value
		}
		if attr.Name.Local == "lang" {
			msg.Lang = attr.Value
		}
	}

	// decode inner elements
	for {
		t, err := d.Token()
		if err != nil {
			return err
		}

		switch tt := t.(type) {

		case xml.StartElement:
			if msgExt := TypeRegistry.GetMsgExtension(tt.Name); msgExt != nil {
				// Decode message extension
				err = d.DecodeElement(msgExt, &tt)
				if err != nil {
					return err
				}
				msg.Extensions = append(msg.Extensions, msgExt)
			} else {
				// Decode standard message sub-elements
				var err error
				switch tt.Name.Local {
				case "body":
					err = d.DecodeElement(&msg.Body, &tt)
				case "thread":
					err = d.DecodeElement(&msg.Thread, &tt)
				case "subject":
					err = d.DecodeElement(&msg.Subject, &tt)
				case "error":
					err = d.DecodeElement(&msg.Error, &tt)
				}
				if err != nil {
					return err
				}
			}

		case xml.EndElement:
			if tt == start.End() {
				return nil
			}
		}
	}
}