t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes())
}
}
- decState := newDecodeState(nil, &b)
+ decState := newDecodeState(nil, b)
for u := uint64(0); ; u = (u + 1) * 7 {
b.Reset()
encState.encodeUint(u)
var b = new(bytes.Buffer)
encState := newEncoderState(nil, b)
encState.encodeInt(i)
- decState := newDecodeState(nil, &b)
+ decState := newDecodeState(nil, b)
decState.buf = make([]byte, 8)
j := decState.decodeInt()
if i != j {
func newDecodeStateFromData(data []byte) *decodeState {
b := bytes.NewBuffer(data)
- state := newDecodeState(nil, &b)
+ state := newDecodeState(nil, b)
state.fieldnum = -1
return state
}
}
}
}
-
}
// A struct with all basic types, stored in interfaces.
int(1), int8(1), int16(1), int32(1), int64(1),
uint(1), uint8(1), uint16(1), uint32(1), uint64(1),
float32(1), 1.0,
- complex64(0i), complex128(0i),
+ complex64(1i), complex128(1i),
true,
"hello",
[]byte("sailor"),
dec *Decoder
// The buffer is stored with an extra indirection because it may be replaced
// if we load a type during decode (when reading an interface value).
- b **bytes.Buffer
+ b *bytes.Buffer
fieldnum int // the last field number read.
buf []byte
}
-func newDecodeState(dec *Decoder, b **bytes.Buffer) *decodeState {
+// We pass the bytes.Buffer separately for easier testing of the infrastructure
+// without requiring a full Decoder.
+func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState {
d := new(decodeState)
d.dec = dec
- d.b = b
+ d.b = buf
d.buf = make([]byte, uint64Size)
return d
}
return *(*uintptr)(up)
}
-func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) (err os.Error) {
+func (dec *Decoder) decodeSingle(engine *decEngine, rtyp reflect.Type, p uintptr, indir int) (err os.Error) {
defer catchError(&err)
p = allocate(rtyp, p, indir)
- state := newDecodeState(dec, b)
+ state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField
basep := p
delta := int(state.decodeUint())
return nil
}
-func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) (err os.Error) {
+func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, p uintptr, indir int) (err os.Error) {
defer catchError(&err)
p = allocate(rtyp, p, indir)
- state := newDecodeState(dec, b)
+ state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1
basep := p
for state.b.Len() > 0 {
return nil
}
-func (dec *Decoder) ignoreStruct(engine *decEngine, b **bytes.Buffer) (err os.Error) {
+func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
defer catchError(&err)
- state := newDecodeState(dec, b)
+ state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1
for state.b.Len() > 0 {
delta := int(state.decodeUint())
if !ok {
errorf("gob: name not registered for interface: %q", name)
}
+ // Read the type id of the concrete value.
+ concreteId := dec.decodeTypeSequence(true)
+ if concreteId < 0 {
+ error(dec.err)
+ }
+ // Byte count of value is next; we don't care what it is (it's there
+ // in case we want to ignore the value by skipping it completely).
+ state.decodeUint()
// Read the concrete value.
value := reflect.MakeZero(typ)
- dec.decodeValueFromBuffer(value, false, true)
+ dec.decodeValue(concreteId, value)
if dec.err != nil {
error(dec.err)
}
if err != nil {
error(err)
}
- dec.decodeValueFromBuffer(nil, true, true)
- if dec.err != nil {
- error(err)
+ id := dec.decodeTypeSequence(true)
+ if id < 0 {
+ error(dec.err)
}
+ // At this point, the decoder buffer contains the value. Just toss it.
+ state.b.Reset()
}
// Index by Go types.
}
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs
- err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir)
+ err = dec.decodeStruct(*enginePtr, t, uintptr(p), i.indir)
if err != nil {
error(err)
}
}
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs
- state.dec.ignoreStruct(*enginePtr, state.b)
+ state.dec.ignoreStruct(*enginePtr)
}
}
}
if t, ok := builtinIdToType[remoteId]; ok {
wireStruct, _ = t.(*structType)
} else {
- wireStruct = dec.wireType[remoteId].StructT
+ wire := dec.wireType[remoteId]
+ if wire == nil {
+ error(errBadType)
+ }
+ wireStruct = wire.StructT
}
if wireStruct == nil {
errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String())
return
}
-func (dec *Decoder) decode(wireId typeId, val reflect.Value) os.Error {
+func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error {
// Dereference down to the underlying struct type.
rt, indir := indirect(val.Type())
enginePtr, err := dec.getDecEnginePtr(wireId, rt)
name := rt.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
}
- return dec.decodeStruct(engine, st, dec.state.b, uintptr(val.Addr()), indir)
+ return dec.decodeStruct(engine, st, uintptr(val.Addr()), indir)
}
- return dec.decodeSingle(engine, rt, dec.state.b, uintptr(val.Addr()), indir)
+ return dec.decodeSingle(engine, rt, uintptr(val.Addr()), indir)
}
func init() {
type Decoder struct {
mutex sync.Mutex // each item must be received atomically
r io.Reader // source of the data
+ buf bytes.Buffer // buffer for more efficient i/o from r
wireType map[typeId]*wireType // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
ignorerCache map[typeId]**decEngine // ditto for ignored objects
- state *decodeState // reads data from in-memory buffer
countState *decodeState // reads counts from wire
- buf []byte
- countBuf [9]byte // counts may be uint64s (unlikely!), require 9 bytes
- byteBuffer *bytes.Buffer
+ countBuf []byte // used for decoding integers while parsing messages
+ tmp []byte // temporary storage for i/o; saves reallocating
err os.Error
}
dec := new(Decoder)
dec.r = r
dec.wireType = make(map[typeId]*wireType)
- dec.state = newDecodeState(dec, &dec.byteBuffer) // buffer set in Decode()
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
dec.ignorerCache = make(map[typeId]**decEngine)
+ dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
return dec
}
-// recvType loads the definition of a type and reloads the Decoder's buffer.
+// recvType loads the definition of a type.
func (dec *Decoder) recvType(id typeId) {
// Have we already seen this type? That's an error
- if dec.wireType[id] != nil {
+ if id < firstUserId || dec.wireType[id] != nil {
dec.err = os.ErrorString("gob: duplicate type received")
return
}
// Type:
wire := new(wireType)
- dec.err = dec.decode(tWireType, reflect.NewValue(wire))
+ dec.err = dec.decodeValue(tWireType, reflect.NewValue(wire))
if dec.err != nil {
return
}
// Remember we've seen this type.
dec.wireType[id] = wire
-
- // Load the next parcel.
- dec.recvMessage()
-}
-
-// Decode reads the next value from the connection and stores
-// it in the data represented by the empty interface value.
-// The value underlying e must be the correct type for the next
-// data item received, and must be a pointer.
-func (dec *Decoder) Decode(e interface{}) os.Error {
- value := reflect.NewValue(e)
- // If e represents a value as opposed to a pointer, the answer won't
- // get back to the caller. Make sure it's a pointer.
- if value.Type().Kind() != reflect.Ptr {
- dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
- return dec.err
- }
- return dec.DecodeValue(value)
}
// recvMessage reads the next count-delimited item from the input. It is the converse
-// of Encoder.writeMessage.
-func (dec *Decoder) recvMessage() {
+// of Encoder.writeMessage. It returns false on EOF or other error reading the message.
+func (dec *Decoder) recvMessage() bool {
// Read a count.
- var nbytes uint64
- nbytes, _, dec.err = decodeUintReader(dec.r, dec.countBuf[0:])
- if dec.err != nil {
- return
+ nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ return false
}
- dec.readMessage(int(nbytes), dec.r)
+ dec.readMessage(int(nbytes))
+ return dec.err == nil
}
// readMessage reads the next nbytes bytes from the input.
-func (dec *Decoder) readMessage(nbytes int, r io.Reader) {
+func (dec *Decoder) readMessage(nbytes int) {
// Allocate the buffer.
- if nbytes > len(dec.buf) {
- dec.buf = make([]byte, nbytes+1000)
+ if cap(dec.tmp) < nbytes {
+ dec.tmp = make([]byte, nbytes+100) // room to grow
}
- dec.byteBuffer = bytes.NewBuffer(dec.buf[0:nbytes])
+ dec.tmp = dec.tmp[:nbytes]
// Read the data
- _, dec.err = io.ReadFull(r, dec.buf[0:nbytes])
+ _, dec.err = io.ReadFull(dec.r, dec.tmp)
if dec.err != nil {
if dec.err == os.EOF {
dec.err = io.ErrUnexpectedEOF
}
return
}
+ dec.buf.Write(dec.tmp)
}
-// decodeValueFromBuffer grabs the next value from the input. The Decoder's
-// buffer already contains data. If the next item in the buffer is a type
-// descriptor, it will be necessary to reload the buffer; recvType does that.
-func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignoreInterfaceValue, countPresent bool) {
- for dec.state.b.Len() > 0 {
- // Receive a type id.
- id := typeId(dec.state.decodeInt())
+// toInt turns an encoded uint64 into an int, according to the marshaling rules.
+func toInt(x uint64) int64 {
+ i := int64(x >> 1)
+ if x&1 != 0 {
+ i = ^i
+ }
+ return i
+}
+
+func (dec *Decoder) nextInt() int64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return toInt(n)
+}
+
+func (dec *Decoder) nextUint() uint64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return n
+}
- // Is it a new type?
- if id < 0 { // 0 is the error state, handled above
- // If the id is negative, we have a type.
- dec.recvType(-id)
- if dec.err != nil {
+// decodeTypeSequence parses:
+// TypeSequence
+// (TypeDefinition DelimitedTypeDefinition*)?
+// and returns the type id of the next value. It returns -1 at
+// EOF. Upon return, the remainder of dec.buf is the value to be
+// decoded. If this is an interface value, it can be ignored by
+// simply resetting that buffer.
+func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
+ for dec.err == nil {
+ if dec.buf.Len() == 0 {
+ if !dec.recvMessage() {
break
}
- continue
}
-
- // Make sure the type has been defined already or is a builtin type (for
- // top-level singleton values).
- if dec.wireType[id] == nil && builtinIdToType[id] == nil {
- dec.err = errBadType
- break
+ // Receive a type id.
+ id := typeId(dec.nextInt())
+ if id >= 0 {
+ // Value follows.
+ return id
}
- // An interface value is preceded by a byte count.
- if countPresent {
- count := int(dec.state.decodeUint())
- if ignoreInterfaceValue {
- // An interface value is preceded by a byte count. Just skip that many bytes.
- dec.state.b.Next(int(count))
+ // Type definition for (-id) follows.
+ dec.recvType(-id)
+ // When decoding an interface, after a type there may be a
+ // DelimitedValue still in the buffer. Skip its count.
+ // (Alternatively, the buffer is empty and the byte count
+ // will be absorbed by recvMessage.)
+ if dec.buf.Len() > 0 {
+ if !isInterface {
+ dec.err = os.ErrorString("extra data in buffer")
break
}
- // Otherwise fall through and decode it.
+ dec.nextUint()
}
- dec.err = dec.decode(id, value)
- break
}
+ return -1
+}
+
+// Decode reads the next value from the connection and stores
+// it in the data represented by the empty interface value.
+// The value underlying e must be the correct type for the next
+// data item received, and must be a pointer.
+func (dec *Decoder) Decode(e interface{}) os.Error {
+ value := reflect.NewValue(e)
+ // If e represents a value as opposed to a pointer, the answer won't
+ // get back to the caller. Make sure it's a pointer.
+ if value.Type().Kind() != reflect.Ptr {
+ dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
+ return dec.err
+ }
+ return dec.DecodeValue(value)
}
// DecodeValue reads the next value from the connection and stores
dec.mutex.Lock()
defer dec.mutex.Unlock()
+ dec.buf.Reset() // In case data lingers from previous invocation.
dec.err = nil
- dec.recvMessage()
- if dec.err != nil {
- return dec.err
+ id := dec.decodeTypeSequence(false)
+ if id >= 0 {
+ dec.err = dec.decodeValue(id, value)
}
- dec.decodeValueFromBuffer(value, false, false)
return dec.err
}
if err != nil {
error(err)
}
- // Send (and maybe first define) the type id.
- enc.sendTypeDescriptor(typ)
- // Encode the value into a new buffer.
+ // Define the type id if necessary.
+ enc.sendTypeDescriptor(enc.writer(), state, typ)
+ // Send the type id.
+ enc.sendTypeId(state, typ)
+ // Encode the value into a new buffer. Any nested type definitions
+ // should be written to b, before the encoded value.
+ enc.pushWriter(b)
data := new(bytes.Buffer)
err = enc.encode(data, iv.Elem())
if err != nil {
error(err)
}
- state.encodeUint(uint64(data.Len()))
- _, err = state.b.Write(data.Bytes())
- if err != nil {
+ enc.popWriter()
+ enc.writeMessage(b, data)
+ if enc.err != nil {
error(err)
}
}
// other side of a connection.
type Encoder struct {
mutex sync.Mutex // each item must be sent atomically
- w io.Writer // where to send the data
+ w []io.Writer // where to send the data
sent map[reflect.Type]typeId // which types we've already sent
- state *encoderState // so we can encode integers, strings directly
countState *encoderState // stage for writing counts
buf []byte // for collecting the output.
err os.Error
// NewEncoder returns a new encoder that will transmit on the io.Writer.
func NewEncoder(w io.Writer) *Encoder {
enc := new(Encoder)
- enc.w = w
+ enc.w = []io.Writer{w}
enc.sent = make(map[reflect.Type]typeId)
- enc.state = newEncoderState(enc, new(bytes.Buffer))
enc.countState = newEncoderState(enc, new(bytes.Buffer))
return enc
}
+// writer() returns the innermost writer the encoder is using
+func (enc *Encoder) writer() io.Writer {
+ return enc.w[len(enc.w)-1]
+}
+
+// pushWriter adds a writer to the encoder.
+func (enc *Encoder) pushWriter(w io.Writer) {
+ enc.w = append(enc.w, w)
+}
+
+// popWriter pops the innermost writer.
+func (enc *Encoder) popWriter() {
+ enc.w = enc.w[0 : len(enc.w)-1]
+}
+
func (enc *Encoder) badType(rt reflect.Type) {
enc.setError(os.ErrorString("gob: can't encode type " + rt.String()))
}
if enc.err == nil { // remember the first.
enc.err = err
}
- enc.state.b.Reset()
}
-// Send the data item preceded by a unsigned count of its length.
-func (enc *Encoder) send() {
- // Encode the length.
- enc.countState.encodeUint(uint64(enc.state.b.Len()))
+// writeMessage sends the data item preceded by a unsigned count of its length.
+func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
+ enc.countState.encodeUint(uint64(b.Len()))
// Build the buffer.
countLen := enc.countState.b.Len()
- total := countLen + enc.state.b.Len()
+ total := countLen + b.Len()
if total > len(enc.buf) {
enc.buf = make([]byte, total+1000) // extra for growth
}
// TODO(r): avoid the extra copy here.
enc.countState.b.Read(enc.buf[0:countLen])
// Now the data.
- enc.state.b.Read(enc.buf[countLen:total])
+ b.Read(enc.buf[countLen:total])
// Write the data.
- _, err := enc.w.Write(enc.buf[0:total])
+ _, err := w.Write(enc.buf[0:total])
if err != nil {
enc.setError(err)
}
}
-func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
+func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
// Drill down to the base type.
rt, _ := indirect(origt)
}
// Send the pair (-id, type)
// Id:
- enc.state.encodeInt(-int64(info.id))
+ state.encodeInt(-int64(info.id))
// Type:
- enc.encode(enc.state.b, reflect.NewValue(info.wire))
- enc.send()
+ enc.encode(state.b, reflect.NewValue(info.wire))
+ enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
switch st := rt.(type) {
case *reflect.StructType:
for i := 0; i < st.NumField(); i++ {
- enc.sendType(st.Field(i).Type)
+ enc.sendType(w, state, st.Field(i).Type)
}
case reflect.ArrayOrSliceType:
- enc.sendType(st.Elem())
+ enc.sendType(w, state, st.Elem())
}
return true
}
// sendTypeId makes sure the remote side knows about this type.
// It will send a descriptor if this is the first time the type has been
-// sent. Regardless, it sends the id.
-func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) {
+// sent.
+func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, rt reflect.Type) {
// Make sure the type is known to the other side.
// First, have we already sent this type?
if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it.
- sent := enc.sendType(rt)
+ sent := enc.sendType(w, state, rt)
if enc.err != nil {
return
}
enc.sent[rt] = info.id
}
}
+}
+// sendTypeId sends the id, which must have already been defined.
+func (enc *Encoder) sendTypeId(state *encoderState, rt reflect.Type) {
// Identify the type of this top-level value.
- enc.state.encodeInt(int64(enc.sent[rt]))
+ state.encodeInt(int64(enc.sent[rt]))
}
// EncodeValue transmits the data item represented by the reflection value,
enc.mutex.Lock()
defer enc.mutex.Unlock()
+ // Remove any nested writers remaining due to previous errors.
+ enc.w = enc.w[0:1]
+
enc.err = nil
rt, _ := indirect(value.Type())
- // Sanity check only: encoder should never come in with data present.
- if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 {
- enc.err = os.ErrorString("encoder: buffer not empty")
- return enc.err
- }
+ state := newEncoderState(enc, new(bytes.Buffer))
- enc.sendTypeDescriptor(rt)
+ enc.sendTypeDescriptor(enc.writer(), state, rt)
+ enc.sendTypeId(state, rt)
if enc.err != nil {
return enc.err
}
// Encode the object.
- err := enc.encode(enc.state.b, value)
+ err := enc.encode(state.b, value)
if err != nil {
enc.setError(err)
} else {
- enc.send()
+ enc.writeMessage(enc.writer(), state.b)
}
return enc.err
t.Fatal("decode error:", err)
}
}
+
+// Another bug from golang-nuts, involving nested interfaces.
+type Bug0Outer struct {
+ Bug0Field interface{}
+}
+
+type Bug0Inner struct {
+ A int
+}
+
+func TestNestedInterfaces(t *testing.T) {
+ var buf bytes.Buffer
+ e := NewEncoder(&buf)
+ d := NewDecoder(&buf)
+ Register(new(Bug0Outer))
+ Register(new(Bug0Inner))
+ f := &Bug0Outer{&Bug0Outer{&Bug0Inner{7}}}
+ var v interface{} = f
+ err := e.Encode(&v)
+ if err != nil {
+ t.Fatal("Encode:", err)
+ }
+ Debug(bytes.NewBuffer(buf.Bytes()))
+ err = d.Decode(&v)
+ if err != nil {
+ t.Fatal("Decode:", err)
+ }
+ // Make sure it decoded correctly.
+ outer1, ok := v.(*Bug0Outer)
+ if !ok {
+ t.Fatalf("v not Bug0Outer: %T", v)
+ }
+ outer2, ok := outer1.Bug0Field.(*Bug0Outer)
+ if !ok {
+ t.Fatalf("v.Bug0Field not Bug0Outer: %T", outer1.Bug0Field)
+ }
+ inner, ok := outer2.Bug0Field.(*Bug0Inner)
+ if !ok {
+ t.Fatalf("v.Bug0Field.Bug0Field not Bug0Inner: %T", outer2.Bug0Field)
+ }
+ if inner.A != 7 {
+ t.Fatalf("final value %d; expected %d", inner.A, 7)
+ }
+}