--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "gob";
+ "io";
+ "os";
+ "reflect";
+ "sync";
+)
+
+type Decoder struct {
+ sync.Mutex; // each item must be received atomically
+ seen map[TypeId] *wireType; // which types we've already seen described
+ state *DecState; // so we can encode integers, strings directly
+}
+
+func NewDecoder(r io.Reader) *Decoder {
+ dec := new(Decoder);
+ dec.seen = make(map[TypeId] *wireType);
+ dec.state = new(DecState);
+ dec.state.r = r; // the rest isn't important; all we need is buffer and reader
+
+ return dec;
+}
+
+func (dec *Decoder) recvType(id TypeId) {
+ // Have we already seen this type? That's an error
+ if wt_, alreadySeen := dec.seen[id]; alreadySeen {
+ dec.state.err = os.ErrorString("gob: duplicate type received");
+ return
+ }
+
+ // Type:
+ wire := new(wireType);
+ Decode(dec.state.r, wire);
+ // Remember we've seen this type.
+ dec.seen[id] = wire;
+}
+
+// The value underlying e must be the correct type for the next
+// value to be received for this decoder.
+func (dec *Decoder) Decode(e interface{}) os.Error {
+ rt, indir := indirect(reflect.Typeof(e));
+
+ // Make sure we're single-threaded through here.
+ dec.Lock();
+ defer dec.Unlock();
+
+ var id TypeId;
+ for dec.state.err == nil {
+ // Receive a type id.
+ id = TypeId(DecodeInt(dec.state));
+
+ // If the id is positive, we have a value. 0 is the error state
+ if id >= 0 {
+ break;
+ }
+
+ // The id is negative; a type descriptor follows.
+ dec.recvType(-id);
+ }
+ if dec.state.err != nil {
+ return dec.state.err
+ }
+
+ info := getTypeInfo(rt);
+
+ // Check type compatibility.
+ // TODO(r): need to make the decoder work correctly if the wire type is compatible
+ // but not equal to the local type (e.g, extra fields).
+ if info.wire.name != dec.seen[id].name {
+ dec.state.err = os.ErrorString("gob decode: incorrect type for wire value");
+ return dec.state.err
+ }
+
+ // Receive a value.
+ Decode(dec.state.r, e);
+
+ // Release and return.
+ return dec.state.err
+}
import (
"bytes";
-"fmt"; // DELETE
"gob";
"os";
"reflect";
next *ET1;
}
+// Like ET1 but with a different name for a field
+type ET3 struct {
+ a int;
+ et2 *ET2;
+ differentNext *ET1;
+}
+
+// Like ET1 but with a different type for a field
+type ET4 struct {
+ a int;
+ et2 *ET1;
+ next *ET2;
+}
+
+// Like ET1 but with a different type for a self-referencing field
+type ET5 struct {
+ a int;
+ et2 *ET2;
+ next *ET1;
+}
+
func TestBasicEncoder(t *testing.T) {
b := new(bytes.Buffer);
enc := NewEncoder(b);
t.Error("2nd round: not at eof;", b.Len(), "bytes left")
}
}
+
+func TestEncoderDecoder(t *testing.T) {
+ b := new(bytes.Buffer);
+ enc := NewEncoder(b);
+ et1 := new(ET1);
+ et1.a = 7;
+ et1.et2 = new(ET2);
+ enc.Encode(et1);
+ if enc.state.err != nil {
+ t.Error("encoder fail:", enc.state.err)
+ }
+ dec := NewDecoder(b);
+ newEt1 := new(ET1);
+ dec.Decode(newEt1);
+ if dec.state.err != nil {
+ t.Fatalf("error decoding ET1:", dec.state.err);
+ }
+
+ if !reflect.DeepEqual(et1, newEt1) {
+ t.Fatalf("invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1);
+ }
+ if b.Len() != 0 {
+ t.Error("not at eof;", b.Len(), "bytes left")
+ }
+
+ enc.Encode(et1);
+ newEt1 = new(ET1);
+ dec.Decode(newEt1);
+ if dec.state.err != nil {
+ t.Fatalf("round 2: error decoding ET1:", dec.state.err);
+ }
+ if !reflect.DeepEqual(et1, newEt1) {
+ t.Fatalf("round 2: invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1);
+ }
+ if b.Len() != 0 {
+ t.Error("round 2: not at eof;", b.Len(), "bytes left")
+ }
+
+ // Now test with a running encoder/decoder pair that we recognize a type mismatch.
+ enc.Encode(et1);
+ if enc.state.err != nil {
+ t.Error("round 3: encoder fail:", enc.state.err)
+ }
+ newEt2 := new(ET2);
+ dec.Decode(newEt2);
+ if dec.state.err == nil {
+ t.Fatalf("round 3: expected `bad type' error decoding ET2");
+ }
+}
+
+// Run one value through the encoder/decoder, but use the wrong type.
+func badTypeCheck(e interface{}, msg string, t *testing.T) {
+ b := new(bytes.Buffer);
+ enc := NewEncoder(b);
+ et1 := new(ET1);
+ et1.a = 7;
+ et1.et2 = new(ET2);
+ enc.Encode(et1);
+ if enc.state.err != nil {
+ t.Error("encoder fail:", enc.state.err)
+ }
+ dec := NewDecoder(b);
+ dec.Decode(e);
+ if dec.state.err == nil {
+ t.Error("expected error for", msg);
+ }
+}
+
+// Test that we recognize a bad type the first time.
+func TestWrongTypeDecoder(t *testing.T) {
+ badTypeCheck(new(ET2), "different number of fields", t);
+ badTypeCheck(new(ET3), "different name of field", t);
+ badTypeCheck(new(ET4), "different type of field", t);
+ badTypeCheck(new(ET5), "different type of self-reference field", t);
+}