]> Cypherpunks repositories - gostls13.git/commitdiff
gob: make nested interfaces work.
authorRob Pike <r@golang.org>
Fri, 28 Jan 2011 18:53:42 +0000 (10:53 -0800)
committerRob Pike <r@golang.org>
Fri, 28 Jan 2011 18:53:42 +0000 (10:53 -0800)
Also clean up the code, make it more regular.

Fixes #1416.

R=rsc
CC=golang-dev
https://golang.org/cl/3985047

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/decoder.go
src/pkg/gob/encode.go
src/pkg/gob/encoder.go
src/pkg/gob/encoder_test.go

index eb1ff5c616d01174c77c379d7efeec2d461824ad..fe1f60ba75a2825820388f12a11eacfb780d975a 100644 (file)
@@ -58,7 +58,7 @@ func TestUintCodec(t *testing.T) {
                        t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes())
                }
        }
-       decState := newDecodeState(nil, &b)
+       decState := newDecodeState(nil, b)
        for u := uint64(0); ; u = (u + 1) * 7 {
                b.Reset()
                encState.encodeUint(u)
@@ -77,7 +77,7 @@ func verifyInt(i int64, t *testing.T) {
        var b = new(bytes.Buffer)
        encState := newEncoderState(nil, b)
        encState.encodeInt(i)
-       decState := newDecodeState(nil, &b)
+       decState := newDecodeState(nil, b)
        decState.buf = make([]byte, 8)
        j := decState.decodeInt()
        if i != j {
@@ -315,7 +315,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un
 
 func newDecodeStateFromData(data []byte) *decodeState {
        b := bytes.NewBuffer(data)
-       state := newDecodeState(nil, &b)
+       state := newDecodeState(nil, b)
        state.fieldnum = -1
        return state
 }
@@ -1162,7 +1162,6 @@ func TestInterface(t *testing.T) {
                        }
                }
        }
-
 }
 
 // A struct with all basic types, stored in interfaces.
@@ -1182,7 +1181,7 @@ func TestInterfaceBasic(t *testing.T) {
                int(1), int8(1), int16(1), int32(1), int64(1),
                uint(1), uint8(1), uint16(1), uint32(1), uint64(1),
                float32(1), 1.0,
-               complex64(0i), complex128(0i),
+               complex64(1i), complex128(1i),
                true,
                "hello",
                []byte("sailor"),
index 73a926961212686eee141baee2f465ddf2af56c5..db8b968700a8d5c3e4550582ae276dc7a8ab5c1e 100644 (file)
@@ -30,15 +30,17 @@ type decodeState struct {
        dec *Decoder
        // The buffer is stored with an extra indirection because it may be replaced
        // if we load a type during decode (when reading an interface value).
-       b        **bytes.Buffer
+       b        *bytes.Buffer
        fieldnum int // the last field number read.
        buf      []byte
 }
 
-func newDecodeState(dec *Decoder, b **bytes.Buffer) *decodeState {
+// We pass the bytes.Buffer separately for easier testing of the infrastructure
+// without requiring a full Decoder.
+func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState {
        d := new(decodeState)
        d.dec = dec
-       d.b = b
+       d.b = buf
        d.buf = make([]byte, uint64Size)
        return d
 }
@@ -407,10 +409,10 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
        return *(*uintptr)(up)
 }
 
-func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) (err os.Error) {
+func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr, indir int) (err os.Error) {
        defer catchError(&err)
        p = allocate(rtyp, p, indir)
-       state := newDecodeState(dec, b)
+       state := newDecodeState(dec, &dec.buf)
        state.fieldnum = singletonField
        basep := p
        delta := int(state.decodeUint())
@@ -426,10 +428,10 @@ func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes
        return nil
 }
 
-func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) (err os.Error) {
+func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, p uintptr, indir int) (err os.Error) {
        defer catchError(&err)
        p = allocate(rtyp, p, indir)
-       state := newDecodeState(dec, b)
+       state := newDecodeState(dec, &dec.buf)
        state.fieldnum = -1
        basep := p
        for state.b.Len() > 0 {
@@ -456,9 +458,9 @@ func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b
        return nil
 }
 
-func (dec *Decoder) ignoreStruct(engine *decEngine, b **bytes.Buffer) (err os.Error) {
+func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
        defer catchError(&err)
-       state := newDecodeState(dec, b)
+       state := newDecodeState(dec, &dec.buf)
        state.fieldnum = -1
        for state.b.Len() > 0 {
                delta := int(state.decodeUint())
@@ -614,9 +616,17 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt
        if !ok {
                errorf("gob: name not registered for interface: %q", name)
        }
+       // Read the type id of the concrete value.
+       concreteId := dec.decodeTypeSequence(true)
+       if concreteId < 0 {
+               error(dec.err)
+       }
+       // Byte count of value is next; we don't care what it is (it's there
+       // in case we want to ignore the value by skipping it completely).
+       state.decodeUint()
        // Read the concrete value.
        value := reflect.MakeZero(typ)
-       dec.decodeValueFromBuffer(value, false, true)
+       dec.decodeValue(concreteId, value)
        if dec.err != nil {
                error(dec.err)
        }
@@ -639,10 +649,12 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
        if err != nil {
                error(err)
        }
-       dec.decodeValueFromBuffer(nil, true, true)
-       if dec.err != nil {
-               error(err)
+       id := dec.decodeTypeSequence(true)
+       if id < 0 {
+               error(dec.err)
        }
+       // At this point, the decoder buffer contains the value. Just toss it.
+       state.b.Reset()
 }
 
 // Index by Go types.
@@ -733,7 +745,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
                                // indirect through enginePtr to delay evaluation for recursive structs
-                               err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir)
+                               err = dec.decodeStruct(*enginePtr, t, uintptr(p), i.indir)
                                if err != nil {
                                        error(err)
                                }
@@ -798,7 +810,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
                                // indirect through enginePtr to delay evaluation for recursive structs
-                               state.dec.ignoreStruct(*enginePtr, state.b)
+                               state.dec.ignoreStruct(*enginePtr)
                        }
                }
        }
@@ -907,7 +919,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
        if t, ok := builtinIdToType[remoteId]; ok {
                wireStruct, _ = t.(*structType)
        } else {
-               wireStruct = dec.wireType[remoteId].StructT
+               wire := dec.wireType[remoteId]
+               if wire == nil {
+                       error(errBadType)
+               }
+               wireStruct = wire.StructT
        }
        if wireStruct == nil {
                errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String())
@@ -976,7 +992,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
        return
 }
 
-func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error {
+func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error {
        // Dereference down to the underlying struct type.
        rt, indir := indirect(val.Type())
        enginePtr, err := dec.getDecEnginePtr(wireId, rt)
@@ -989,9 +1005,9 @@ func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error {
                        name := rt.Name()
                        return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
                }
-               return dec.decodeStruct(engine, st, dec.state.b, uintptr(val.Addr()), indir)
+               return dec.decodeStruct(engine, st, uintptr(val.Addr()), indir)
        }
-       return dec.decodeSingle(engine, rt, dec.state.b, uintptr(val.Addr()), indir)
+       return dec.decodeSingle(engine, rt, uintptr(val.Addr()), indir)
 }
 
 func init() {
index 10c72c37f586a9e46edd8ff12e306e15d77cd4ba..7527c5f1ffc2f7c4b0da5937d953105adb4f94bc 100644 (file)
@@ -17,14 +17,13 @@ import (
 type Decoder struct {
        mutex        sync.Mutex                              // each item must be received atomically
        r            io.Reader                               // source of the data
+       buf          bytes.Buffer                            // buffer for more efficient i/o from r
        wireType     map[typeId]*wireType                    // map from remote ID to local description
        decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
        ignorerCache map[typeId]**decEngine                  // ditto for ignored objects
-       state        *decodeState                            // reads data from in-memory buffer
        countState   *decodeState                            // reads counts from wire
-       buf          []byte
-       countBuf     [9]byte // counts may be uint64s (unlikely!), require 9 bytes
-       byteBuffer   *bytes.Buffer
+       countBuf     []byte                                  // used for decoding integers while parsing messages
+       tmp          []byte                                  // temporary storage for i/o; saves reallocating
        err          os.Error
 }
 
@@ -33,116 +32,138 @@ func NewDecoder(r io.Reader) *Decoder {
        dec := new(Decoder)
        dec.r = r
        dec.wireType = make(map[typeId]*wireType)
-       dec.state = newDecodeState(dec, &dec.byteBuffer) // buffer set in Decode()
        dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
        dec.ignorerCache = make(map[typeId]**decEngine)
+       dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
 
        return dec
 }
 
-// recvType loads the definition of a type and reloads the Decoder's buffer.
+// recvType loads the definition of a type.
 func (dec *Decoder) recvType(id typeId) {
        // Have we already seen this type?  That's an error
-       if dec.wireType[id] != nil {
+       if id < firstUserId || dec.wireType[id] != nil {
                dec.err = os.ErrorString("gob: duplicate type received")
                return
        }
 
        // Type:
        wire := new(wireType)
-       dec.err = dec.decode(tWireType, reflect.NewValue(wire))
+       dec.err = dec.decodeValue(tWireType, reflect.NewValue(wire))
        if dec.err != nil {
                return
        }
        // Remember we've seen this type.
        dec.wireType[id] = wire
-
-       // Load the next parcel.
-       dec.recvMessage()
-}
-
-// Decode reads the next value from the connection and stores
-// it in the data represented by the empty interface value.
-// The value underlying e must be the correct type for the next
-// data item received, and must be a pointer.
-func (dec *Decoder) Decode(e interface{}) os.Error {
-       value := reflect.NewValue(e)
-       // If e represents a value as opposed to a pointer, the answer won't
-       // get back to the caller.  Make sure it's a pointer.
-       if value.Type().Kind() != reflect.Ptr {
-               dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
-               return dec.err
-       }
-       return dec.DecodeValue(value)
 }
 
 // recvMessage reads the next count-delimited item from the input. It is the converse
-// of Encoder.writeMessage.
-func (dec *Decoder) recvMessage() {
+// of Encoder.writeMessage. It returns false on EOF or other error reading the message.
+func (dec *Decoder) recvMessage() bool {
        // Read a count.
-       var nbytes uint64
-       nbytes, _, dec.err = decodeUintReader(dec.r, dec.countBuf[0:])
-       if dec.err != nil {
-               return
+       nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
+       if err != nil {
+               dec.err = err
+               return false
        }
-       dec.readMessage(int(nbytes), dec.r)
+       dec.readMessage(int(nbytes))
+       return dec.err == nil
 }
 
 // readMessage reads the next nbytes bytes from the input.
-func (dec *Decoder) readMessage(nbytes int, r io.Reader) {
+func (dec *Decoder) readMessage(nbytes int) {
        // Allocate the buffer.
-       if nbytes > len(dec.buf) {
-               dec.buf = make([]byte, nbytes+1000)
+       if cap(dec.tmp) < nbytes {
+               dec.tmp = make([]byte, nbytes+100) // room to grow
        }
-       dec.byteBuffer = bytes.NewBuffer(dec.buf[0:nbytes])
+       dec.tmp = dec.tmp[:nbytes]
 
        // Read the data
-       _, dec.err = io.ReadFull(r, dec.buf[0:nbytes])
+       _, dec.err = io.ReadFull(dec.r, dec.tmp)
        if dec.err != nil {
                if dec.err == os.EOF {
                        dec.err = io.ErrUnexpectedEOF
                }
                return
        }
+       dec.buf.Write(dec.tmp)
 }
 
-// decodeValueFromBuffer grabs the next value from the input. The Decoder's
-// buffer already contains data.  If the next item in the buffer is a type
-// descriptor, it will be necessary to reload the buffer; recvType does that.
-func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignoreInterfaceValue, countPresent bool) {
-       for dec.state.b.Len() > 0 {
-               // Receive a type id.
-               id := typeId(dec.state.decodeInt())
+// toInt turns an encoded uint64 into an int, according to the marshaling rules.
+func toInt(x uint64) int64 {
+       i := int64(x >> 1)
+       if x&1 != 0 {
+               i = ^i
+       }
+       return i
+}
+
+func (dec *Decoder) nextInt() int64 {
+       n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+       if err != nil {
+               dec.err = err
+       }
+       return toInt(n)
+}
+
+func (dec *Decoder) nextUint() uint64 {
+       n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+       if err != nil {
+               dec.err = err
+       }
+       return n
+}
 
-               // Is it a new type?
-               if id < 0 { // 0 is the error state, handled above
-                       // If the id is negative, we have a type.
-                       dec.recvType(-id)
-                       if dec.err != nil {
+// decodeTypeSequence parses:
+// TypeSequence
+//     (TypeDefinition DelimitedTypeDefinition*)?
+// and returns the type id of the next value.  It returns -1 at
+// EOF.  Upon return, the remainder of dec.buf is the value to be
+// decoded.  If this is an interface value, it can be ignored by
+// simply resetting that buffer.
+func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
+       for dec.err == nil {
+               if dec.buf.Len() == 0 {
+                       if !dec.recvMessage() {
                                break
                        }
-                       continue
                }
-
-               // Make sure the type has been defined already or is a builtin type (for
-               // top-level singleton values).
-               if dec.wireType[id] == nil && builtinIdToType[id] == nil {
-                       dec.err = errBadType
-                       break
+               // Receive a type id.
+               id := typeId(dec.nextInt())
+               if id >= 0 {
+                       // Value follows.
+                       return id
                }
-               // An interface value is preceded by a byte count.
-               if countPresent {
-                       count := int(dec.state.decodeUint())
-                       if ignoreInterfaceValue {
-                               // An interface value is preceded by a byte count. Just skip that many bytes.
-                               dec.state.b.Next(int(count))
+               // Type definition for (-id) follows.
+               dec.recvType(-id)
+               // When decoding an interface, after a type there may be a
+               // DelimitedValue still in the buffer.  Skip its count.
+               // (Alternatively, the buffer is empty and the byte count
+               // will be absorbed by recvMessage.)
+               if dec.buf.Len() > 0 {
+                       if !isInterface {
+                               dec.err = os.ErrorString("extra data in buffer")
                                break
                        }
-                       // Otherwise fall through and decode it.
+                       dec.nextUint()
                }
-               dec.err = dec.decode(id, value)
-               break
        }
+       return -1
+}
+
+// Decode reads the next value from the connection and stores
+// it in the data represented by the empty interface value.
+// The value underlying e must be the correct type for the next
+// data item received, and must be a pointer.
+func (dec *Decoder) Decode(e interface{}) os.Error {
+       value := reflect.NewValue(e)
+       // If e represents a value as opposed to a pointer, the answer won't
+       // get back to the caller.  Make sure it's a pointer.
+       if value.Type().Kind() != reflect.Ptr {
+               dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
+               return dec.err
+       }
+       return dec.DecodeValue(value)
 }
 
 // DecodeValue reads the next value from the connection and stores
@@ -154,12 +175,12 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
        dec.mutex.Lock()
        defer dec.mutex.Unlock()
 
+       dec.buf.Reset() // In case data lingers from previous invocation.
        dec.err = nil
-       dec.recvMessage()
-       if dec.err != nil {
-               return dec.err
+       id := dec.decodeTypeSequence(false)
+       if id >= 0 {
+               dec.err = dec.decodeValue(id, value)
        }
-       dec.decodeValueFromBuffer(value, false, false)
        return dec.err
 }
 
index 832bc340fd48474732fe18256757b2672a4c3e1c..2e5ba2487c0280f403e4bdda9d8eb88129e8b653 100644 (file)
@@ -395,17 +395,21 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
        if err != nil {
                error(err)
        }
-       // Send (and maybe first define) the type id.
-       enc.sendTypeDescriptor(typ)
-       // Encode the value into a new buffer.
+       // Define the type id if necessary.
+       enc.sendTypeDescriptor(enc.writer(), state, typ)
+       // Send the type id.
+       enc.sendTypeId(state, typ)
+       // Encode the value into a new buffer.  Any nested type definitions
+       // should be written to b, before the encoded value.
+       enc.pushWriter(b)
        data := new(bytes.Buffer)
        err = enc.encode(data, iv.Elem())
        if err != nil {
                error(err)
        }
-       state.encodeUint(uint64(data.Len()))
-       _, err = state.b.Write(data.Bytes())
-       if err != nil {
+       enc.popWriter()
+       enc.writeMessage(b, data)
+       if enc.err != nil {
                error(err)
        }
 }
index 8869b262982d835569defb09b8a4cb47d2c83dfa..29ba44057ec0900d7a1ad798e33a9a5c1f9725a4 100644 (file)
@@ -16,9 +16,8 @@ import (
 // other side of a connection.
 type Encoder struct {
        mutex      sync.Mutex              // each item must be sent atomically
-       w          io.Writer               // where to send the data
+       w          []io.Writer             // where to send the data
        sent       map[reflect.Type]typeId // which types we've already sent
-       state      *encoderState           // so we can encode integers, strings directly
        countState *encoderState           // stage for writing counts
        buf        []byte                  // for collecting the output.
        err        os.Error
@@ -27,13 +26,27 @@ type Encoder struct {
 // NewEncoder returns a new encoder that will transmit on the io.Writer.
 func NewEncoder(w io.Writer) *Encoder {
        enc := new(Encoder)
-       enc.w = w
+       enc.w = []io.Writer{w}
        enc.sent = make(map[reflect.Type]typeId)
-       enc.state = newEncoderState(enc, new(bytes.Buffer))
        enc.countState = newEncoderState(enc, new(bytes.Buffer))
        return enc
 }
 
+// writer() returns the innermost writer the encoder is using
+func (enc *Encoder) writer() io.Writer {
+       return enc.w[len(enc.w)-1]
+}
+
+// pushWriter adds a writer to the encoder.
+func (enc *Encoder) pushWriter(w io.Writer) {
+       enc.w = append(enc.w, w)
+}
+
+// popWriter pops the innermost writer.
+func (enc *Encoder) popWriter() {
+       enc.w = enc.w[0 : len(enc.w)-1]
+}
+
 func (enc *Encoder) badType(rt reflect.Type) {
        enc.setError(os.ErrorString("gob: can't encode type " + rt.String()))
 }
@@ -42,16 +55,14 @@ func (enc *Encoder) setError(err os.Error) {
        if enc.err == nil { // remember the first.
                enc.err = err
        }
-       enc.state.b.Reset()
 }
 
-// Send the data item preceded by a unsigned count of its length.
-func (enc *Encoder) send() {
-       // Encode the length.
-       enc.countState.encodeUint(uint64(enc.state.b.Len()))
+// writeMessage sends the data item preceded by a unsigned count of its length.
+func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
+       enc.countState.encodeUint(uint64(b.Len()))
        // Build the buffer.
        countLen := enc.countState.b.Len()
-       total := countLen + enc.state.b.Len()
+       total := countLen + b.Len()
        if total > len(enc.buf) {
                enc.buf = make([]byte, total+1000) // extra for growth
        }
@@ -59,15 +70,15 @@ func (enc *Encoder) send() {
        // TODO(r): avoid the extra copy here.
        enc.countState.b.Read(enc.buf[0:countLen])
        // Now the data.
-       enc.state.b.Read(enc.buf[countLen:total])
+       b.Read(enc.buf[countLen:total])
        // Write the data.
-       _, err := enc.w.Write(enc.buf[0:total])
+       _, err := w.Write(enc.buf[0:total])
        if err != nil {
                enc.setError(err)
        }
 }
 
-func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
+func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
        // Drill down to the base type.
        rt, _ := indirect(origt)
 
@@ -112,10 +123,10 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
        }
        // Send the pair (-id, type)
        // Id:
-       enc.state.encodeInt(-int64(info.id))
+       state.encodeInt(-int64(info.id))
        // Type:
-       enc.encode(enc.state.b, reflect.NewValue(info.wire))
-       enc.send()
+       enc.encode(state.b, reflect.NewValue(info.wire))
+       enc.writeMessage(w, state.b)
        if enc.err != nil {
                return
        }
@@ -128,10 +139,10 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
        switch st := rt.(type) {
        case *reflect.StructType:
                for i := 0; i < st.NumField(); i++ {
-                       enc.sendType(st.Field(i).Type)
+                       enc.sendType(w, state, st.Field(i).Type)
                }
        case reflect.ArrayOrSliceType:
-               enc.sendType(st.Elem())
+               enc.sendType(w, state, st.Elem())
        }
        return true
 }
@@ -144,13 +155,13 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
 
 // sendTypeId makes sure the remote side knows about this type.
 // It will send a descriptor if this is the first time the type has been
-// sent.  Regardless, it sends the id.
-func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) {
+// sent.
+func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt reflect.Type) {
        // Make sure the type is known to the other side.
        // First, have we already sent this type?
        if _, alreadySent := enc.sent[rt]; !alreadySent {
                // No, so send it.
-               sent := enc.sendType(rt)
+               sent := enc.sendType(w, state, rt)
                if enc.err != nil {
                        return
                }
@@ -168,9 +179,12 @@ func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) {
                        enc.sent[rt] = info.id
                }
        }
+}
 
+// sendTypeId sends the id, which must have already been defined.
+func (enc *Encoder) sendTypeId(state *encoderState, rt reflect.Type) {
        // Identify the type of this top-level value.
-       enc.state.encodeInt(int64(enc.sent[rt]))
+       state.encodeInt(int64(enc.sent[rt]))
 }
 
 // EncodeValue transmits the data item represented by the reflection value,
@@ -181,26 +195,26 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
        enc.mutex.Lock()
        defer enc.mutex.Unlock()
 
+       // Remove any nested writers remaining due to previous errors.
+       enc.w = enc.w[0:1]
+
        enc.err = nil
        rt, _ := indirect(value.Type())
 
-       // Sanity check only: encoder should never come in with data present.
-       if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 {
-               enc.err = os.ErrorString("encoder: buffer not empty")
-               return enc.err
-       }
+       state := newEncoderState(enc, new(bytes.Buffer))
 
-       enc.sendTypeDescriptor(rt)
+       enc.sendTypeDescriptor(enc.writer(), state, rt)
+       enc.sendTypeId(state, rt)
        if enc.err != nil {
                return enc.err
        }
 
        // Encode the object.
-       err := enc.encode(enc.state.b, value)
+       err := enc.encode(state.b, value)
        if err != nil {
                enc.setError(err)
        } else {
-               enc.send()
+               enc.writeMessage(enc.writer(), state.b)
        }
 
        return enc.err
index 402fd2a13db094aa0e4798a67cc89f446f38aaea..d0449bd649bbb3df0679da339da12edf46894689 100644 (file)
@@ -383,3 +383,47 @@ func TestInterfaceIndirect(t *testing.T) {
                t.Fatal("decode error:", err)
        }
 }
+
+// Another bug from golang-nuts, involving nested interfaces.
+type Bug0Outer struct {
+       Bug0Field interface{}
+}
+
+type Bug0Inner struct {
+       A int
+}
+
+func TestNestedInterfaces(t *testing.T) {
+       var buf bytes.Buffer
+       e := NewEncoder(&buf)
+       d := NewDecoder(&buf)
+       Register(new(Bug0Outer))
+       Register(new(Bug0Inner))
+       f := &Bug0Outer{&Bug0Outer{&Bug0Inner{7}}}
+       var v interface{} = f
+       err := e.Encode(&v)
+       if err != nil {
+               t.Fatal("Encode:", err)
+       }
+       Debug(bytes.NewBuffer(buf.Bytes()))
+       err = d.Decode(&v)
+       if err != nil {
+               t.Fatal("Decode:", err)
+       }
+       // Make sure it decoded correctly.
+       outer1, ok := v.(*Bug0Outer)
+       if !ok {
+               t.Fatalf("v not Bug0Outer: %T", v)
+       }
+       outer2, ok := outer1.Bug0Field.(*Bug0Outer)
+       if !ok {
+               t.Fatalf("v.Bug0Field not Bug0Outer: %T", outer1.Bug0Field)
+       }
+       inner, ok := outer2.Bug0Field.(*Bug0Inner)
+       if !ok {
+               t.Fatalf("v.Bug0Field.Bug0Field not Bug0Inner: %T", outer2.Bug0Field)
+       }
+       if inner.A != 7 {
+               t.Fatalf("final value %d; expected %d", inner.A, 7)
+       }
+}