]> Cypherpunks repositories - gostls13.git/commitdiff
gob: compute information about a user's type once.
authorRob Pike <r@golang.org>
Tue, 22 Feb 2011 20:31:57 +0000 (12:31 -0800)
committerRob Pike <r@golang.org>
Tue, 22 Feb 2011 20:31:57 +0000 (12:31 -0800)
Other than maybe cleaning the code up a bit, this has
little practical effect for now, but lays the foundation
for remembering the method set of a type, which can
be expensive.

R=rsc
CC=golang-dev
https://golang.org/cl/4193041

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/encode.go
src/pkg/gob/encoder.go
src/pkg/gob/type.go

index fe1f60ba75a2825820388f12a11eacfb780d975a..c09736221ed05a0b0c277ca6df7ce88243ed58f7 100644 (file)
@@ -984,7 +984,7 @@ func TestInvalidField(t *testing.T) {
        var bad0 Bad0
        bad0.ch = make(chan int)
        b := new(bytes.Buffer)
-       err := nilEncoder.encode(b, reflect.NewValue(&bad0))
+       err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
        if err == nil {
                t.Error("expected error; got none")
        } else if strings.Index(err.String(), "type") < 0 {
index 9667f6157ec1ce306c154e86593503054599c634..d3f87144dac274e20501f3b3d39bef19a87e2cff 100644 (file)
@@ -409,9 +409,9 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
        return *(*uintptr)(up)
 }
 
-func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr, indir int) (err os.Error) {
+func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
        defer catchError(&err)
-       p = allocate(rtyp, p, indir)
+       p = allocate(ut.base, p, ut.indir)
        state := newDecodeState(dec, &dec.buf)
        state.fieldnum = singletonField
        basep := p
@@ -428,9 +428,13 @@ func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr
        return nil
 }
 
-func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, p uintptr, indir int) (err os.Error) {
+// Indir is for the value, not the type.  At the time of the call it may
+// differ from ut.indir, which was computed when the engine was built.
+// This state cannot arise for decodeSingle, which is called directly
+// from the user's value, not from the innards of an engine.
+func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) {
        defer catchError(&err)
-       p = allocate(rtyp, p, indir)
+       p = allocate(ut.base.(*reflect.StructType), p, indir)
        state := newDecodeState(dec, &dec.buf)
        state.fieldnum = -1
        basep := p
@@ -702,7 +706,9 @@ var decIgnoreOpMap = map[typeId]decOp{
 // Return the decoding op for the base type under rt and
 // the indirection count to reach it.
 func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) {
-       typ, indir := indirect(rt)
+       ut := userType(rt)
+       typ := ut.base
+       indir := ut.indir
        var op decOp
        k := typ.Kind()
        if int(k) < len(decOpMap) {
@@ -757,8 +763,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
                                error(err)
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               // indirect through enginePtr to delay evaluation for recursive structs
-                               err = dec.decodeStruct(*enginePtr, t, uintptr(p), i.indir)
+                               // indirect through enginePtr to delay evaluation for recursive structs.
+                               err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir)
                                if err != nil {
                                        error(err)
                                }
@@ -837,7 +843,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
 // Answers the question for basic types, arrays, and slices.
 // Structs are considered ok; fields will be checked later.
 func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
-       fr, _ = indirect(fr)
+       fr = userType(fr).base
        switch t := fr.(type) {
        default:
                // map, chan, etc: cannot handle.
@@ -882,7 +888,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
                } else {
                        sw = dec.wireType[fw].SliceT
                }
-               elem, _ := indirect(t.Elem())
+               elem := userType(t.Elem()).base
                return sw != nil && dec.compatibleType(elem, sw.Elem)
        case *reflect.StructType:
                return true
@@ -1026,20 +1032,22 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error {
                return dec.decodeIgnoredValue(wireId)
        }
        // Dereference down to the underlying struct type.
-       rt, indir := indirect(val.Type())
-       enginePtr, err := dec.getDecEnginePtr(wireId, rt)
+       ut := userType(val.Type())
+       base := ut.base
+       indir := ut.indir
+       enginePtr, err := dec.getDecEnginePtr(wireId, base)
        if err != nil {
                return err
        }
        engine := *enginePtr
-       if st, ok := rt.(*reflect.StructType); ok {
+       if st, ok := base.(*reflect.StructType); ok {
                if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 {
-                       name := rt.Name()
+                       name := base.Name()
                        return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
                }
-               return dec.decodeStruct(engine, st, uintptr(val.Addr()), indir)
+               return dec.decodeStruct(engine, ut, uintptr(val.Addr()), indir)
        }
-       return dec.decodeSingle(engine, rt, uintptr(val.Addr()), indir)
+       return dec.decodeSingle(engine, ut, uintptr(val.Addr()))
 }
 
 func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error {
index 2e5ba2487c0280f403e4bdda9d8eb88129e8b653..c5570409b4fed2c17ad0264edf727a267eca52a0 100644 (file)
@@ -384,10 +384,10 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
                return
        }
 
-       typ, _ := indirect(iv.Elem().Type())
-       name, ok := concreteTypeToName[typ]
+       ut := userType(iv.Elem().Type())
+       name, ok := concreteTypeToName[ut.base]
        if !ok {
-               errorf("gob: type not registered for interface: %s", typ)
+               errorf("gob: type not registered for interface: %s", ut.base)
        }
        // Send the name.
        state.encodeUint(uint64(len(name)))
@@ -396,14 +396,14 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
                error(err)
        }
        // Define the type id if necessary.
-       enc.sendTypeDescriptor(enc.writer(), state, typ)
+       enc.sendTypeDescriptor(enc.writer(), state, ut)
        // Send the type id.
-       enc.sendTypeId(state, typ)
+       enc.sendTypeId(state, ut)
        // Encode the value into a new buffer.  Any nested type definitions
        // should be written to b, before the encoded value.
        enc.pushWriter(b)
        data := new(bytes.Buffer)
-       err = enc.encode(data, iv.Elem())
+       err = enc.encode(data, iv.Elem(), ut)
        if err != nil {
                error(err)
        }
@@ -437,7 +437,9 @@ var encOpMap = []encOp{
 // Return the encoding op for the base type under rt and
 // the indirection count to reach it.
 func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
-       typ, indir := indirect(rt)
+       ut := userType(rt)
+       typ := ut.base
+       indir := ut.indir
        var op encOp
        k := typ.Kind()
        if int(k) < len(encOpMap) {
@@ -559,14 +561,12 @@ func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine {
        return enc.getEncEngine(rt)
 }
 
-func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value) (err os.Error) {
+func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) {
        defer catchError(&err)
-       // Dereference down to the underlying object.
-       rt, indir := indirect(value.Type())
-       for i := 0; i < indir; i++ {
+       for i := 0; i < ut.indir; i++ {
                value = reflect.Indirect(value)
        }
-       engine := enc.lockAndGetEncEngine(rt)
+       engine := enc.lockAndGetEncEngine(ut.base)
        if value.Type().Kind() == reflect.Struct {
                enc.encodeStruct(b, engine, value.Addr())
        } else {
index 29ba44057ec0900d7a1ad798e33a9a5c1f9725a4..1419a278445f4fcad4fb50e7881ced90a58d171f 100644 (file)
@@ -80,7 +80,8 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
 
 func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
        // Drill down to the base type.
-       rt, _ := indirect(origt)
+       ut := userType(origt)
+       rt := ut.base
 
        switch rt := rt.(type) {
        default:
@@ -125,7 +126,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ
        // Id:
        state.encodeInt(-int64(info.id))
        // Type:
-       enc.encode(state.b, reflect.NewValue(info.wire))
+       enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
        enc.writeMessage(w, state.b)
        if enc.err != nil {
                return
@@ -153,15 +154,16 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
        return enc.EncodeValue(reflect.NewValue(e))
 }
 
-// sendTypeId makes sure the remote side knows about this type.
+// sendTypeDescriptor makes sure the remote side knows about this type.
 // It will send a descriptor if this is the first time the type has been
 // sent.
-func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt reflect.Type) {
+func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
        // Make sure the type is known to the other side.
-       // First, have we already sent this type?
-       if _, alreadySent := enc.sent[rt]; !alreadySent {
+       // First, have we already sent this (base) type?
+       base := ut.base
+       if _, alreadySent := enc.sent[base]; !alreadySent {
                // No, so send it.
-               sent := enc.sendType(w, state, rt)
+               sent := enc.sendType(w, state, base)
                if enc.err != nil {
                        return
                }
@@ -170,21 +172,21 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt refl
                // need to send the type info but we do need to update enc.sent.
                if !sent {
                        typeLock.Lock()
-                       info, err := getTypeInfo(rt)
+                       info, err := getTypeInfo(base)
                        typeLock.Unlock()
                        if err != nil {
                                enc.setError(err)
                                return
                        }
-                       enc.sent[rt] = info.id
+                       enc.sent[base] = info.id
                }
        }
 }
 
 // sendTypeId sends the id, which must have already been defined.
-func (enc *Encoder) sendTypeId(state *encoderState, rt reflect.Type) {
+func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) {
        // Identify the type of this top-level value.
-       state.encodeInt(int64(enc.sent[rt]))
+       state.encodeInt(int64(enc.sent[ut.base]))
 }
 
 // EncodeValue transmits the data item represented by the reflection value,
@@ -199,18 +201,18 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
        enc.w = enc.w[0:1]
 
        enc.err = nil
-       rt, _ := indirect(value.Type())
+       ut := userType(value.Type())
 
        state := newEncoderState(enc, new(bytes.Buffer))
 
-       enc.sendTypeDescriptor(enc.writer(), state, rt)
-       enc.sendTypeId(state, rt)
+       enc.sendTypeDescriptor(enc.writer(), state, ut)
+       enc.sendTypeId(state, ut)
        if enc.err != nil {
                return enc.err
        }
 
        // Encode the object.
-       err := enc.encode(state.b, value)
+       err := enc.encode(state.b, value, ut)
        if err != nil {
                enc.setError(err)
        } else {
index f613f6e8a9edbd196c1194729aa27ff881838877..c9c116abf830acd8782d9fb01453ff5e87a5cbd0 100644 (file)
@@ -11,12 +11,43 @@ import (
        "sync"
 )
 
-// Reflection types are themselves interface values holding structs
-// describing the type.  Each type has a different struct so that struct can
-// be the kind.  For example, if typ is the reflect type for an int8, typ is
-// a pointer to a reflect.Int8Type struct; if typ is the reflect type for a
-// function, typ is a pointer to a reflect.FuncType struct; we use the type
-// of that pointer as the kind.
+// userTypeInfo stores the information associated with a type the user has handed
+// to the package.  It's computed once and stored in a map keyed by reflection
+// type.
+type userTypeInfo struct {
+       user  reflect.Type // the type the user handed us
+       base  reflect.Type // the base type after all indirections
+       indir int          // number of indirections to reach the base type
+}
+
+var (
+       // Protected by an RWMutex because we read it a lot and write
+       // it only when we see a new type, typically when compiling.
+       userTypeLock  sync.RWMutex
+       userTypeCache = make(map[reflect.Type]*userTypeInfo)
+)
+
+// userType returns, and saves, the information associated with user-provided type rt
+func userType(rt reflect.Type) *userTypeInfo {
+       userTypeLock.RLock()
+       ut := userTypeCache[rt]
+       userTypeLock.RUnlock()
+       if ut != nil {
+               return ut
+       }
+       // Now set the value under the write lock.
+       userTypeLock.Lock()
+       defer userTypeLock.Unlock()
+       if ut = userTypeCache[rt]; ut != nil {
+               // Lost the race; not a problem.
+               return ut
+       }
+       ut = new(userTypeInfo)
+       ut.user = rt
+       ut.base, ut.indir = indirect(rt)
+       userTypeCache[rt] = ut
+       return ut
+}
 
 // A typeId represents a gob Type as an integer that can be passed on the wire.
 // Internally, typeIds are used as keys to a map to recover the underlying type info.
@@ -110,6 +141,7 @@ var (
 
 // Predefined because it's needed by the Decoder
 var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id
+var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType)
 
 func init() {
        // Some magic numbers to make sure there are no surprises.
@@ -133,6 +165,7 @@ func init() {
        }
        nextId = firstUserId
        registerBasics()
+       wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil)))
 }
 
 // Array type
@@ -317,10 +350,10 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
                field := make([]*fieldType, t.NumField())
                for i := 0; i < t.NumField(); i++ {
                        f := t.Field(i)
-                       typ, _ := indirect(f.Type)
+                       typ := userType(f.Type).base
                        tname := typ.Name()
                        if tname == "" {
-                               t, _ := indirect(f.Type)
+                               t := userType(f.Type).base
                                tname = t.String()
                        }
                        gt, err := getType(tname, f.Type)
@@ -341,7 +374,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
 // getType returns the Gob type describing the given reflect.Type.
 // typeLock must be held.
 func getType(name string, rt reflect.Type) (gobType, os.Error) {
-       rt, _ = indirect(rt)
+       rt = userType(rt).base
        typ, present := types[rt]
        if present {
                return typ, nil
@@ -371,6 +404,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
        types[rt] = typ
        setTypeId(typ)
        checkId(expect, nextId)
+       userType(rt) // might as well cache it now
        return nextId
 }
 
@@ -473,18 +507,18 @@ func RegisterName(name string, value interface{}) {
                // reserved for nil
                panic("attempt to register empty name")
        }
-       rt, _ := indirect(reflect.Typeof(value))
+       base := userType(reflect.Typeof(value)).base
        // Check for incompatible duplicates.
-       if t, ok := nameToConcreteType[name]; ok && t != rt {
+       if t, ok := nameToConcreteType[name]; ok && t != base {
                panic("gob: registering duplicate types for " + name)
        }
-       if n, ok := concreteTypeToName[rt]; ok && n != name {
-               panic("gob: registering duplicate names for " + rt.String())
+       if n, ok := concreteTypeToName[base]; ok && n != name {
+               panic("gob: registering duplicate names for " + base.String())
        }
        // Store the name and type provided by the user....
        nameToConcreteType[name] = reflect.Typeof(value)
        // but the flattened type in the type table, since that's what decode needs.
-       concreteTypeToName[rt] = name
+       concreteTypeToName[base] = name
 }
 
 // Register records a type, identified by a value for that type, under its