forked from jshiffer/matterbridge
340 lines
6.2 KiB
Go
340 lines
6.2 KiB
Go
|
package msgpack
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
|
||
|
"github.com/vmihailenco/msgpack/v5/msgpcode"
|
||
|
)
|
||
|
|
||
|
var errArrayStruct = errors.New("msgpack: number of fields in array-encoded struct has changed")
|
||
|
|
||
|
var (
|
||
|
mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
|
||
|
mapStringStringType = mapStringStringPtrType.Elem()
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil))
|
||
|
mapStringInterfaceType = mapStringInterfacePtrType.Elem()
|
||
|
)
|
||
|
|
||
|
func decodeMapValue(d *Decoder, v reflect.Value) error {
|
||
|
n, err := d.DecodeMapLen()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
typ := v.Type()
|
||
|
if n == -1 {
|
||
|
v.Set(reflect.Zero(typ))
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
if v.IsNil() {
|
||
|
v.Set(reflect.MakeMap(typ))
|
||
|
}
|
||
|
if n == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return d.decodeTypedMapValue(v, n)
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) decodeMapDefault() (interface{}, error) {
|
||
|
if d.mapDecoder != nil {
|
||
|
return d.mapDecoder(d)
|
||
|
}
|
||
|
return d.DecodeMap()
|
||
|
}
|
||
|
|
||
|
// DecodeMapLen decodes map length. Length is -1 when map is nil.
|
||
|
func (d *Decoder) DecodeMapLen() (int, error) {
|
||
|
c, err := d.readCode()
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
if msgpcode.IsExt(c) {
|
||
|
if err = d.skipExtHeader(c); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
c, err = d.readCode()
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
return d.mapLen(c)
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) mapLen(c byte) (int, error) {
|
||
|
if c == msgpcode.Nil {
|
||
|
return -1, nil
|
||
|
}
|
||
|
if c >= msgpcode.FixedMapLow && c <= msgpcode.FixedMapHigh {
|
||
|
return int(c & msgpcode.FixedMapMask), nil
|
||
|
}
|
||
|
if c == msgpcode.Map16 {
|
||
|
size, err := d.uint16()
|
||
|
return int(size), err
|
||
|
}
|
||
|
if c == msgpcode.Map32 {
|
||
|
size, err := d.uint32()
|
||
|
return int(size), err
|
||
|
}
|
||
|
return 0, unexpectedCodeError{code: c, hint: "map length"}
|
||
|
}
|
||
|
|
||
|
func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
|
||
|
mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string)
|
||
|
return d.decodeMapStringStringPtr(mptr)
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
|
||
|
size, err := d.DecodeMapLen()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if size == -1 {
|
||
|
*ptr = nil
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
m := *ptr
|
||
|
if m == nil {
|
||
|
*ptr = make(map[string]string, min(size, maxMapSize))
|
||
|
m = *ptr
|
||
|
}
|
||
|
|
||
|
for i := 0; i < size; i++ {
|
||
|
mk, err := d.DecodeString()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
mv, err := d.DecodeString()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
m[mk] = mv
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
|
||
|
ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{})
|
||
|
return d.decodeMapStringInterfacePtr(ptr)
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
|
||
|
m, err := d.DecodeMap()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
*ptr = m
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) DecodeMap() (map[string]interface{}, error) {
|
||
|
n, err := d.DecodeMapLen()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if n == -1 {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
m := make(map[string]interface{}, min(n, maxMapSize))
|
||
|
|
||
|
for i := 0; i < n; i++ {
|
||
|
mk, err := d.DecodeString()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
mv, err := d.decodeInterfaceCond()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
m[mk] = mv
|
||
|
}
|
||
|
|
||
|
return m, nil
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) DecodeUntypedMap() (map[interface{}]interface{}, error) {
|
||
|
n, err := d.DecodeMapLen()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if n == -1 {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
m := make(map[interface{}]interface{}, min(n, maxMapSize))
|
||
|
|
||
|
for i := 0; i < n; i++ {
|
||
|
mk, err := d.decodeInterfaceCond()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
mv, err := d.decodeInterfaceCond()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
m[mk] = mv
|
||
|
}
|
||
|
|
||
|
return m, nil
|
||
|
}
|
||
|
|
||
|
// DecodeTypedMap decodes a typed map. Typed map is a map that has a fixed type for keys and values.
|
||
|
// Key and value types may be different.
|
||
|
func (d *Decoder) DecodeTypedMap() (interface{}, error) {
|
||
|
n, err := d.DecodeMapLen()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if n <= 0 {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
key, err := d.decodeInterfaceCond()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
value, err := d.decodeInterfaceCond()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
keyType := reflect.TypeOf(key)
|
||
|
valueType := reflect.TypeOf(value)
|
||
|
|
||
|
if !keyType.Comparable() {
|
||
|
return nil, fmt.Errorf("msgpack: unsupported map key: %s", keyType.String())
|
||
|
}
|
||
|
|
||
|
mapType := reflect.MapOf(keyType, valueType)
|
||
|
mapValue := reflect.MakeMap(mapType)
|
||
|
mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
|
||
|
|
||
|
n--
|
||
|
if err := d.decodeTypedMapValue(mapValue, n); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return mapValue.Interface(), nil
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) decodeTypedMapValue(v reflect.Value, n int) error {
|
||
|
typ := v.Type()
|
||
|
keyType := typ.Key()
|
||
|
valueType := typ.Elem()
|
||
|
|
||
|
for i := 0; i < n; i++ {
|
||
|
mk := reflect.New(keyType).Elem()
|
||
|
if err := d.DecodeValue(mk); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
mv := reflect.New(valueType).Elem()
|
||
|
if err := d.DecodeValue(mv); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
v.SetMapIndex(mk, mv)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) skipMap(c byte) error {
|
||
|
n, err := d.mapLen(c)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
for i := 0; i < n; i++ {
|
||
|
if err := d.Skip(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if err := d.Skip(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func decodeStructValue(d *Decoder, v reflect.Value) error {
|
||
|
c, err := d.readCode()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
n, err := d.mapLen(c)
|
||
|
if err == nil {
|
||
|
return d.decodeStruct(v, n)
|
||
|
}
|
||
|
|
||
|
var err2 error
|
||
|
n, err2 = d.arrayLen(c)
|
||
|
if err2 != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if n <= 0 {
|
||
|
v.Set(reflect.Zero(v.Type()))
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
fields := structs.Fields(v.Type(), d.structTag)
|
||
|
if n != len(fields.List) {
|
||
|
return errArrayStruct
|
||
|
}
|
||
|
|
||
|
for _, f := range fields.List {
|
||
|
if err := f.DecodeValue(d, v); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (d *Decoder) decodeStruct(v reflect.Value, n int) error {
|
||
|
if n == -1 {
|
||
|
v.Set(reflect.Zero(v.Type()))
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
fields := structs.Fields(v.Type(), d.structTag)
|
||
|
for i := 0; i < n; i++ {
|
||
|
name, err := d.decodeStringTemp()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if f := fields.Map[name]; f != nil {
|
||
|
if err := f.DecodeValue(d, v); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if d.flags&disallowUnknownFieldsFlag != 0 {
|
||
|
return fmt.Errorf("msgpack: unknown field %q", name)
|
||
|
}
|
||
|
if err := d.Skip(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|