]> Cypherpunks repositories - gostls13.git/commitdiff
gob: add support for maps.
authorRob Pike <r@golang.org>
Wed, 5 May 2010 23:46:39 +0000 (16:46 -0700)
committerRob Pike <r@golang.org>
Wed, 5 May 2010 23:46:39 +0000 (16:46 -0700)
Because maps are mostly a hidden type, they must be
implemented using reflection values and will not be as
efficient as arrays and slices.

R=rsc
CC=golang-dev
https://golang.org/cl/1127041

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/encode.go
src/pkg/gob/encoder.go
src/pkg/gob/type.go
src/pkg/gob/type_test.go

index df82a5b6bce0339ee32bf038408451eeadc188c5..447b199cb5729b6360f6ff6fcd5ff059009cab4d 100644 (file)
@@ -572,6 +572,7 @@ func TestEndToEnd(t *testing.T) {
        s2 := "string2"
        type T1 struct {
                a, b, c int
+               m       map[string]*float
                n       *[3]float
                strs    *[2]string
                int64s  *[]int64
@@ -579,10 +580,13 @@ func TestEndToEnd(t *testing.T) {
                y       []byte
                t       *T2
        }
+       pi := 3.14159
+       e := 2.71828
        t1 := &T1{
                a:      17,
                b:      18,
                c:      -5,
+               m:      map[string]*float{"pi": &pi, "e": &e},
                n:      &[3]float{1.5, 2.5, 3.5},
                strs:   &[2]string{s1, s2},
                int64s: &[]int64{77, 89, 123412342134},
@@ -921,6 +925,7 @@ type IT0 struct {
        ignore_g string
        ignore_h []byte
        ignore_i *RT1
+       ignore_m map[string]int
        c        float
 }
 
@@ -937,6 +942,7 @@ func TestIgnoredFields(t *testing.T) {
        it0.ignore_g = "pay no attention"
        it0.ignore_h = []byte("to the curtain")
        it0.ignore_i = &RT1{3.1, "hi", 7, "hello"}
+       it0.ignore_m = map[string]int{"one": 1, "two": 2}
 
        b := new(bytes.Buffer)
        NewEncoder(b).Encode(it0)
index 3b14841afd1d34d4061d97d133b5da32f428c35b..fb1e99367008005320cdfd34c6c1dfc9f3292d4f 100644 (file)
@@ -447,6 +447,49 @@ func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp
        return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl)
 }
 
+func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
+       instr := &decInstr{op, 0, indir, 0, ovfl}
+       up := unsafe.Pointer(v.Addr())
+       if indir > 1 {
+               up = decIndirect(up, indir)
+       }
+       op(instr, state, up)
+       return v
+}
+
+func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error {
+       if indir > 0 {
+               up := unsafe.Pointer(p)
+               if *(*unsafe.Pointer)(up) == nil {
+                       // Allocate object.
+                       *(*unsafe.Pointer)(up) = unsafe.New(mtyp)
+               }
+               p = *(*uintptr)(up)
+       }
+       up := unsafe.Pointer(p)
+       if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime
+               // Allocate map.
+               *(*unsafe.Pointer)(up) = unsafe.Pointer(reflect.MakeMap(mtyp).Get())
+       }
+       // Maps cannot be accessed by moving addresses around the way
+       // that slices etc. can.  We must recover a full reflection value for
+       // the iteration.
+       v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue)
+       n := int(decodeUint(state))
+       for i := 0; i < n && state.err == nil; i++ {
+               key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl)
+               if state.err != nil {
+                       break
+               }
+               elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl)
+               if state.err != nil {
+                       break
+               }
+               v.SetElem(key, elem)
+       }
+       return state.err
+}
+
 func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error {
        instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
        for i := 0; i < length && state.err == nil; i++ {
@@ -462,6 +505,18 @@ func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error {
        return ignoreArrayHelper(state, elemOp, length)
 }
 
+func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error {
+       n := int(decodeUint(state))
+       keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")}
+       elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
+       for i := 0; i < n && state.err == nil; i++ {
+               keyOp(keyInstr, state, nil)
+               elemOp(elemInstr, state, nil)
+       }
+       return state.err
+}
+
+
 func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error {
        n := int(uintptr(decodeUint(state)))
        if indir > 0 {
@@ -517,17 +572,25 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
        if !ok {
                // Special cases
                switch t := typ.(type) {
-               case *reflect.SliceType:
+               case *reflect.ArrayType:
                        name = "element of " + name
-                       if _, ok := t.Elem().(*reflect.Uint8Type); ok {
-                               op = decUint8Array
-                               break
+                       elemId := dec.wireType[wireId].arrayT.Elem
+                       elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
+                       if err != nil {
+                               return nil, 0, err
                        }
-                       var elemId typeId
-                       if tt, ok := builtinIdToType[wireId]; ok {
-                               elemId = tt.(*sliceType).Elem
-                       } else {
-                               elemId = dec.wireType[wireId].slice.Elem
+                       ovfl := overflow(name)
+                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                               state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
+                       }
+
+               case *reflect.MapType:
+                       name = "element of " + name
+                       keyId := dec.wireType[wireId].mapT.Key
+                       elemId := dec.wireType[wireId].mapT.Elem
+                       keyOp, keyIndir, err := dec.decOpFor(keyId, t.Key(), name)
+                       if err != nil {
+                               return nil, 0, err
                        }
                        elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
                        if err != nil {
@@ -535,19 +598,32 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
                        }
                        ovfl := overflow(name)
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
+                               up := unsafe.Pointer(p)
+                               if indir > 1 {
+                                       up = decIndirect(up, indir)
+                               }
+                               state.err = decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
                        }
 
-               case *reflect.ArrayType:
+               case *reflect.SliceType:
                        name = "element of " + name
-                       elemId := dec.wireType[wireId].array.Elem
+                       if _, ok := t.Elem().(*reflect.Uint8Type); ok {
+                               op = decUint8Array
+                               break
+                       }
+                       var elemId typeId
+                       if tt, ok := builtinIdToType[wireId]; ok {
+                               elemId = tt.(*sliceType).Elem
+                       } else {
+                               elemId = dec.wireType[wireId].sliceT.Elem
+                       }
                        elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name)
                        if err != nil {
                                return nil, 0, err
                        }
                        ovfl := overflow(name)
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
+                               state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
                        }
 
                case *reflect.StructType:
@@ -575,18 +651,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
                // Special cases
                wire := dec.wireType[wireId]
                switch {
-               case wire.array != nil:
-                       elemId := wire.array.Elem
+               case wire.arrayT != nil:
+                       elemId := wire.arrayT.Elem
                        elemOp, err := dec.decIgnoreOpFor(elemId)
                        if err != nil {
                                return nil, err
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = ignoreArray(state, elemOp, wire.array.Len)
+                               state.err = ignoreArray(state, elemOp, wire.arrayT.Len)
                        }
 
-               case wire.slice != nil:
-                       elemId := wire.slice.Elem
+               case wire.mapT != nil:
+                       keyId := dec.wireType[wireId].mapT.Key
+                       elemId := dec.wireType[wireId].mapT.Elem
+                       keyOp, err := dec.decIgnoreOpFor(keyId)
+                       if err != nil {
+                               return nil, err
+                       }
+                       elemOp, err := dec.decIgnoreOpFor(elemId)
+                       if err != nil {
+                               return nil, err
+                       }
+                       op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+                               state.err = ignoreMap(state, keyOp, elemOp)
+                       }
+
+               case wire.sliceT != nil:
+                       elemId := wire.sliceT.Elem
                        elemOp, err := dec.decIgnoreOpFor(elemId)
                        if err != nil {
                                return nil, err
@@ -595,7 +686,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
                                state.err = ignoreSlice(state, elemOp)
                        }
 
-               case wire.strct != nil:
+               case wire.structT != nil:
                        // Generate a closure that calls out to the engine for the nested type.
                        enginePtr, err := dec.getIgnoreEnginePtr(wireId)
                        if err != nil {
@@ -640,11 +731,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
                return fw == tString
        case *reflect.ArrayType:
                wire, ok := dec.wireType[fw]
-               if !ok || wire.array == nil {
+               if !ok || wire.arrayT == nil {
+                       return false
+               }
+               array := wire.arrayT
+               return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
+       case *reflect.MapType:
+               wire, ok := dec.wireType[fw]
+               if !ok || wire.mapT == nil {
                        return false
                }
-               array := wire.array
-               return ok && t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
+               mapType := wire.mapT
+               return dec.compatibleType(t.Key(), mapType.Key) && dec.compatibleType(t.Elem(), mapType.Elem)
        case *reflect.SliceType:
                // Is it an array of bytes?
                et := t.Elem()
@@ -656,7 +754,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
                if tt, ok := builtinIdToType[fw]; ok {
                        sw = tt.(*sliceType)
                } else {
-                       sw = dec.wireType[fw].slice
+                       sw = dec.wireType[fw].sliceT
                }
                elem, _ := indirect(t.Elem())
                return sw != nil && dec.compatibleType(elem, sw.Elem)
@@ -677,7 +775,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
                if !ok1 || !ok2 {
                        return nil, errNotStruct
                }
-               wireStruct = w.strct
+               wireStruct = w.structT
        }
        engine = new(decEngine)
        engine.instr = make([]decInstr, len(wireStruct.field))
@@ -760,7 +858,7 @@ func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
                return err
        }
        engine := *enginePtr
-       if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].strct.field) > 0 {
+       if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 {
                name := rt.Name()
                return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
        }
index 195d6c647056e26120264098dc5a2ed18662c168..fbea891b98855173a5b1187ee000ac0f211c87b2 100644 (file)
@@ -22,7 +22,7 @@ const uint64Size = unsafe.Sizeof(uint64(0))
 type encoderState struct {
        b        *bytes.Buffer
        err      os.Error             // error encountered during encoding.
-       inArray  bool                 // encoding an array element
+       inArray  bool                 // encoding an array element or map key/value pair
        fieldnum int                  // the last field number written.
        buf      [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
 }
@@ -297,7 +297,7 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
        return state.err
 }
 
-func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length int, elemIndir int) os.Error {
+func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) os.Error {
        state := new(encoderState)
        state.b = b
        state.fieldnum = -1
@@ -319,6 +319,39 @@ func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length i
        return state.err
 }
 
+func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
+       for i := 0; i < indir && v != nil; i++ {
+               v = reflect.Indirect(v)
+       }
+       if v == nil {
+               state.err = os.ErrorString("gob: encodeMap: nil element")
+               return
+       }
+       op(nil, state, unsafe.Pointer(v.Addr()))
+}
+
+func encodeMap(b *bytes.Buffer, rt reflect.Type, p uintptr, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error {
+       state := new(encoderState)
+       state.b = b
+       state.fieldnum = -1
+       state.inArray = true
+       // Maps cannot be accessed by moving addresses around the way
+       // that slices etc. can.  We must recover a full reflection value for
+       // the iteration.
+       v := reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer((p))))
+       mv := reflect.Indirect(v).(*reflect.MapValue)
+       keys := mv.Keys()
+       encodeUint(state, uint64(len(keys)))
+       for _, key := range keys {
+               if state.err != nil {
+                       break
+               }
+               encodeReflectValue(state, key, keyOp, keyIndir)
+               encodeReflectValue(state, mv.Elem(key), elemOp, elemIndir)
+       }
+       return state.err
+}
+
 var encOpMap = map[reflect.Type]encOp{
        valueKind(false):      encBool,
        valueKind(int(0)):     encInt,
@@ -344,7 +377,6 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
        typ, indir := indirect(rt)
        op, ok := encOpMap[reflect.Typeof(typ)]
        if !ok {
-               typ, _ := indirect(rt)
                // Special cases
                switch t := typ.(type) {
                case *reflect.SliceType:
@@ -363,7 +395,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                                        return
                                }
                                state.update(i)
-                               state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), int(slice.Len), indir)
+                               state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len))
                        }
                case *reflect.ArrayType:
                        // True arrays have size in the type.
@@ -373,7 +405,20 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                        }
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
                                state.update(i)
-                               state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir)
+                               state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
+                       }
+               case *reflect.MapType:
+                       keyOp, keyIndir, err := encOpFor(t.Key())
+                       if err != nil {
+                               return nil, 0, err
+                       }
+                       elemOp, elemIndir, err := encOpFor(t.Elem())
+                       if err != nil {
+                               return nil, 0, err
+                       }
+                       op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
+                               state.update(i)
+                               state.err = encodeMap(state.b, typ, uintptr(p), keyOp, elemOp, keyIndir, elemIndir)
                        }
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
index 8ba5031384179ceb659af0d8420c89158e03610f..d65a710802f689cef2efe5406533a61881e03d9f 100644 (file)
@@ -71,9 +71,8 @@
        Structs, arrays and slices are also supported.  Strings and arrays of bytes are
        supported with a special, efficient representation (see below).
 
-       Maps are not supported yet, but they will be.  Interfaces, functions, and channels
-       cannot be sent in a gob.  Attempting to encode a value that contains one will
-       fail.
+       Interfaces, functions, and channels cannot be sent in a gob.  Attempting
+       to encode a value that contains one will fail.
 
        The rest of this comment documents the encoding, details that are not important
        for most users.  Details are presented bottom-up.
@@ -263,10 +262,13 @@ func (enc *Encoder) sendType(origt reflect.Type) {
        case *reflect.ArrayType:
                // arrays must be sent so we know their lengths and element types.
                break
+       case *reflect.MapType:
+               // maps must be sent so we know their lengths and key/value types.
+               break
        case *reflect.StructType:
                // structs must be sent so we know their fields.
                break
-       case *reflect.ChanType, *reflect.FuncType, *reflect.MapType, *reflect.InterfaceType:
+       case *reflect.ChanType, *reflect.FuncType, *reflect.InterfaceType:
                // Probably a bad field in a struct.
                enc.badType(rt)
                return
index 2a178af04b013711d7ae23d5e76d118b739152a8..78793ba44749dc3aec3147a5bff2e5da6e907c20 100644 (file)
@@ -142,6 +142,31 @@ func (a *arrayType) safeString(seen map[typeId]bool) string {
 
 func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
 
+// Map type
+type mapType struct {
+       commonType
+       Key  typeId
+       Elem typeId
+}
+
+func newMapType(name string, key, elem gobType) *mapType {
+       m := &mapType{commonType{name: name}, key.id(), elem.id()}
+       setTypeId(m)
+       return m
+}
+
+func (m *mapType) safeString(seen map[typeId]bool) string {
+       if seen[m._id] {
+               return m.name
+       }
+       seen[m._id] = true
+       key := m.Key.gobType().safeString(seen)
+       elem := m.Elem.gobType().safeString(seen)
+       return fmt.Sprintf("map[%s]%s", key, elem)
+}
+
+func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
+
 // Slice type
 type sliceType struct {
        commonType
@@ -239,6 +264,17 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
                }
                return newArrayType(name, gt, t.Len()), nil
 
+       case *reflect.MapType:
+               kt, err := getType("", t.Key())
+               if err != nil {
+                       return nil, err
+               }
+               vt, err := getType("", t.Elem())
+               if err != nil {
+                       return nil, err
+               }
+               return newMapType(name, kt, vt), nil
+
        case *reflect.SliceType:
                // []byte == []uint8 is a special case
                if _, ok := t.Elem().(*reflect.Uint8Type); ok {
@@ -330,16 +366,18 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
 // using the gob rules for sending a structure, except that we assume the
 // ids for wireType and structType are known.  The relevant pieces
 // are built in encode.go's init() function.
-
+// To maintain binary compatibility, if you extend this type, always put
+// the new fields last.
 type wireType struct {
-       array *arrayType
-       slice *sliceType
-       strct *structType
+       arrayT  *arrayType
+       sliceT  *sliceType
+       structT *structType
+       mapT    *mapType
 }
 
 func (w *wireType) name() string {
-       if w.strct != nil {
-               return w.strct.name
+       if w.structT != nil {
+               return w.structT.name
        }
        return "unknown"
 }
@@ -370,14 +408,16 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
                t := info.id.gobType()
                switch typ := rt.(type) {
                case *reflect.ArrayType:
-                       info.wire = &wireType{array: t.(*arrayType)}
+                       info.wire = &wireType{arrayT: t.(*arrayType)}
+               case *reflect.MapType:
+                       info.wire = &wireType{mapT: t.(*mapType)}
                case *reflect.SliceType:
                        // []byte == []uint8 is a special case handled separately
                        if _, ok := typ.Elem().(*reflect.Uint8Type); !ok {
-                               info.wire = &wireType{slice: t.(*sliceType)}
+                               info.wire = &wireType{sliceT: t.(*sliceType)}
                        }
                case *reflect.StructType:
-                       info.wire = &wireType{strct: t.(*structType)}
+                       info.wire = &wireType{structT: t.(*structType)}
                }
                typeInfoMap[rt] = info
        }
index 3d4871f1db8529b9de9e1239c681187095d0b728..6acfa7135e6c673b52849b73358b8d313e9a01bb 100644 (file)
@@ -105,6 +105,26 @@ func TestSliceType(t *testing.T) {
        }
 }
 
+func TestMapType(t *testing.T) {
+       var m map[string]int
+       mapStringInt := getTypeUnlocked("map", reflect.Typeof(m))
+       var newm map[string]int
+       newMapStringInt := getTypeUnlocked("map1", reflect.Typeof(newm))
+       if mapStringInt != newMapStringInt {
+               t.Errorf("second registration of map[string]int creates new type")
+       }
+       var b map[string]bool
+       mapStringBool := getTypeUnlocked("", reflect.Typeof(b))
+       if mapStringBool == mapStringInt {
+               t.Errorf("registration of map[string]bool creates same type as map[string]int")
+       }
+       str := mapStringBool.string()
+       expected := "map[string]bool"
+       if str != expected {
+               t.Errorf("map printed as %q; expected %q", str, expected)
+       }
+}
+
 type Bar struct {
        x string
 }