]> Cypherpunks repositories - gostls13.git/commitdiff
gobs: error cleanup part 1.
authorRob Pike <r@golang.org>
Fri, 22 Oct 2010 22:16:34 +0000 (15:16 -0700)
committerRob Pike <r@golang.org>
Fri, 22 Oct 2010 22:16:34 +0000 (15:16 -0700)
Remove err from the encoderState and decoderState types, so we're
not always copying to and from various copies of the error, and then
use panic/recover to eliminate lots of error checking.

another pass might take a crack at the same thing for the compilation phase.

R=rsc
CC=golang-dev
https://golang.org/cl/2660042

src/pkg/gob/Makefile
src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/decoder.go
src/pkg/gob/encode.go
src/pkg/gob/encoder.go
src/pkg/gob/encoder_test.go
src/pkg/gob/error.go [new file with mode: 0644]

index 77ec9d98cebe8be3af9a4b2e3b03f168bad03afb..68007c189e20d62032cfe7028e6d344cb962aa4a 100644 (file)
@@ -11,6 +11,7 @@ GOFILES=\
        doc.go\
        encode.go\
        encoder.go\
+       error.go\
        type.go\
 
 include ../../Make.pkg
index ba97f51a1baecb49428defd2ec7f7f57db5541ac..2e52a0f1dd8e43fb0a7a40d464a8f64910faf806 100644 (file)
@@ -37,16 +37,23 @@ var encodeT = []EncodeT{
        {1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
 }
 
+// testError is meant to be used as a deferred function to turn a panic(gobError) into a
+// plain test.Error call.
+func testError(t *testing.T) {
+       if e := recover(); e != nil {
+               t.Error(e.(gobError).Error) // Will re-panic if not one of our errors, such as a runtime error.
+       }
+       return
+}
+
 // Test basic encode/decode routines for unsigned integers
 func TestUintCodec(t *testing.T) {
+       defer testError(t)
        b := new(bytes.Buffer)
        encState := newEncoderState(b)
        for _, tt := range encodeT {
                b.Reset()
                encodeUint(encState, tt.x)
-               if encState.err != nil {
-                       t.Error("encodeUint:", tt.x, encState.err)
-               }
                if !bytes.Equal(tt.b, b.Bytes()) {
                        t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes())
                }
@@ -55,13 +62,7 @@ func TestUintCodec(t *testing.T) {
        for u := uint64(0); ; u = (u + 1) * 7 {
                b.Reset()
                encodeUint(encState, u)
-               if encState.err != nil {
-                       t.Error("encodeUint:", u, encState.err)
-               }
                v := decodeUint(decState)
-               if decState.err != nil {
-                       t.Error("DecodeUint:", u, decState.err)
-               }
                if u != v {
                        t.Errorf("Encode/Decode: sent %#x received %#x", u, v)
                }
@@ -72,18 +73,13 @@ func TestUintCodec(t *testing.T) {
 }
 
 func verifyInt(i int64, t *testing.T) {
+       defer testError(t)
        var b = new(bytes.Buffer)
        encState := newEncoderState(b)
        encodeInt(encState, i)
-       if encState.err != nil {
-               t.Error("encodeInt:", i, encState.err)
-       }
        decState := newDecodeState(&b)
        decState.buf = make([]byte, 8)
        j := decodeInt(decState)
-       if decState.err != nil {
-               t.Error("DecodeInt:", i, decState.err)
-       }
        if i != j {
                t.Errorf("Encode/Decode: sent %#x received %#x", uint64(i), uint64(j))
        }
@@ -320,10 +316,8 @@ func TestScalarEncInstructions(t *testing.T) {
 }
 
 func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) {
+       defer testError(t)
        v := int(decodeUint(state))
-       if state.err != nil {
-               t.Fatalf("decoding %s field: %v", typ, state.err)
-       }
        if v+state.fieldnum != 6 {
                t.Fatalf("decoding field number %d, got %d", 6, v+state.fieldnum)
        }
index 5791c37ecb7c98f0b96e409dcaa4bbed6542f2e9..96d3176847f98f845175f8527e4f4cadd24c4818 100644 (file)
@@ -27,8 +27,8 @@ var (
 type decodeState struct {
        // The buffer is stored with an extra indirection because it may be replaced
        // if we load a type during decode (when reading an interface value).
-       b        **bytes.Buffer
-       err      os.Error
+       b **bytes.Buffer
+       //      err      os.Error
        fieldnum int // the last field number read.
        buf      []byte
 }
@@ -80,21 +80,21 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, err os.Error) {
 // Sets state.err.  If state.err is already non-nil, it does nothing.
 // Does not check for overflow.
 func decodeUint(state *decodeState) (x uint64) {
-       if state.err != nil {
-               return
+       b, err := state.b.ReadByte()
+       if err != nil {
+               error(err)
        }
-       var b uint8
-       b, state.err = state.b.ReadByte()
        if b <= 0x7f { // includes state.err != nil
                return uint64(b)
        }
        nb := -int(int8(b))
        if nb > uint64Size {
-               state.err = errBadUint
-               return
+               error(errBadUint)
+       }
+       n, err := state.b.Read(state.buf[0:nb])
+       if err != nil {
+               error(err)
        }
-       var n int
-       n, state.err = state.b.Read(state.buf[0:nb])
        // Don't need to check error; it's safe to loop regardless.
        // Could check that the high byte is zero but it's not worth it.
        for i := 0; i < n; i++ {
@@ -109,9 +109,6 @@ func decodeUint(state *decodeState) (x uint64) {
 // Does not check for overflow.
 func decodeInt(state *decodeState) int64 {
        x := decodeUint(state)
-       if state.err != nil {
-               return 0
-       }
        if x&1 != 0 {
                return ^int64(x >> 1)
        }
@@ -176,7 +173,7 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
        v := decodeInt(state)
        if v < math.MinInt8 || math.MaxInt8 < v {
-               state.err = i.ovfl
+               error(i.ovfl)
        } else {
                *(*int8)(p) = int8(v)
        }
@@ -191,7 +188,7 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
        v := decodeUint(state)
        if math.MaxUint8 < v {
-               state.err = i.ovfl
+               error(i.ovfl)
        } else {
                *(*uint8)(p) = uint8(v)
        }
@@ -206,7 +203,7 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
        v := decodeInt(state)
        if v < math.MinInt16 || math.MaxInt16 < v {
-               state.err = i.ovfl
+               error(i.ovfl)
        } else {
                *(*int16)(p) = int16(v)
        }
@@ -221,7 +218,7 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
        v := decodeUint(state)
        if math.MaxUint16 < v {
-               state.err = i.ovfl
+               error(i.ovfl)
        } else {
                *(*uint16)(p) = uint16(v)
        }
@@ -236,7 +233,7 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
        v := decodeInt(state)
        if v < math.MinInt32 || math.MaxInt32 < v {
-               state.err = i.ovfl
+               error(i.ovfl)
        } else {
                *(*int32)(p) = int32(v)
        }
@@ -251,7 +248,7 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
        v := decodeUint(state)
        if math.MaxUint32 < v {
-               state.err = i.ovfl
+               error(i.ovfl)
        } else {
                *(*uint32)(p) = uint32(v)
        }
@@ -300,7 +297,7 @@ func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
        }
        // +Inf is OK in both 32- and 64-bit floats.  Underflow is always OK.
        if math.MaxFloat32 < av && av <= math.MaxFloat64 {
-               state.err = i.ovfl
+               error(i.ovfl)
        } else {
                *(*float32)(p) = float32(v)
        }
@@ -407,15 +404,15 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
        return *(*uintptr)(up)
 }
 
-func decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) os.Error {
+func decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) (err os.Error) {
+       defer catchError(&err)
        p = allocate(rtyp, p, indir)
        state := newDecodeState(b)
        state.fieldnum = singletonField
        basep := p
        delta := int(decodeUint(state))
        if delta != 0 {
-               state.err = os.ErrorString("gob decode: corrupted data: non-zero delta for singleton")
-               return state.err
+               errorf("gob decode: corrupted data: non-zero delta for singleton")
        }
        instr := &engine.instr[singletonField]
        ptr := unsafe.Pointer(basep) // offset will be zero
@@ -423,26 +420,26 @@ func decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uint
                ptr = decIndirect(ptr, instr.indir)
        }
        instr.op(instr, state, ptr)
-       return state.err
+       return nil
 }
 
-func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) os.Error {
+func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) (err os.Error) {
+       defer catchError(&err)
        p = allocate(rtyp, p, indir)
        state := newDecodeState(b)
        state.fieldnum = -1
        basep := p
-       for state.b.Len() > 0 && state.err == nil {
+       for state.b.Len() > 0 {
                delta := int(decodeUint(state))
                if delta < 0 {
-                       state.err = os.ErrorString("gob decode: corrupted data: negative delta")
-                       break
+                       errorf("gob decode: corrupted data: negative delta")
                }
-               if state.err != nil || delta == 0 { // struct terminator is zero delta fieldnum
+               if delta == 0 { // struct terminator is zero delta fieldnum
                        break
                }
                fieldnum := state.fieldnum + delta
                if fieldnum >= len(engine.instr) {
-                       state.err = errRange
+                       error(errRange)
                        break
                }
                instr := &engine.instr[fieldnum]
@@ -453,36 +450,35 @@ func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b
                instr.op(instr, state, p)
                state.fieldnum = fieldnum
        }
-       return state.err
+       return nil
 }
 
-func ignoreStruct(engine *decEngine, b **bytes.Buffer) os.Error {
+func ignoreStruct(engine *decEngine, b **bytes.Buffer) (err os.Error) {
+       defer catchError(&err)
        state := newDecodeState(b)
        state.fieldnum = -1
-       for state.b.Len() > 0 && state.err == nil {
+       for state.b.Len() > 0 {
                delta := int(decodeUint(state))
                if delta < 0 {
-                       state.err = os.ErrorString("gob ignore decode: corrupted data: negative delta")
-                       break
+                       errorf("gob ignore decode: corrupted data: negative delta")
                }
-               if state.err != nil || delta == 0 { // struct terminator is zero delta fieldnum
+               if delta == 0 { // struct terminator is zero delta fieldnum
                        break
                }
                fieldnum := state.fieldnum + delta
                if fieldnum >= len(engine.instr) {
-                       state.err = errRange
-                       break
+                       error(errRange)
                }
                instr := &engine.instr[fieldnum]
                instr.op(instr, state, unsafe.Pointer(nil))
                state.fieldnum = fieldnum
        }
-       return state.err
+       return nil
 }
 
-func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) os.Error {
+func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) {
        instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl}
-       for i := 0; i < length && state.err == nil; i++ {
+       for i := 0; i < length; i++ {
                up := unsafe.Pointer(p)
                if elemIndir > 1 {
                        up = decIndirect(up, elemIndir)
@@ -490,17 +486,16 @@ func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uint
                elemOp(instr, state, up)
                p += uintptr(elemWid)
        }
-       return state.err
 }
 
-func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) os.Error {
+func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) {
        if indir > 0 {
                p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect
        }
        if n := decodeUint(state); n != uint64(length) {
-               return os.ErrorString("gob: length mismatch in decodeArray")
+               errorf("gob: length mismatch in decodeArray")
        }
-       return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl)
+       decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl)
 }
 
 func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
@@ -513,7 +508,7 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o
        return v
 }
 
-func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error {
+func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) {
        if indir > 0 {
                p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect
        }
@@ -527,47 +522,38 @@ func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elem
        // the iteration.
        v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue)
        n := int(decodeUint(state))
-       for i := 0; i < n && state.err == nil; i++ {
+       for i := 0; i < n; i++ {
                key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl)
-               if state.err != nil {
-                       break
-               }
                elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl)
-               if state.err != nil {
-                       break
-               }
                v.SetElem(key, elem)
        }
-       return state.err
 }
 
-func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error {
+func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) {
        instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
-       for i := 0; i < length && state.err == nil; i++ {
+       for i := 0; i < length; i++ {
                elemOp(instr, state, nil)
        }
-       return state.err
 }
 
-func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error {
+func ignoreArray(state *decodeState, elemOp decOp, length int) {
        if n := decodeUint(state); n != uint64(length) {
-               return os.ErrorString("gob: length mismatch in ignoreArray")
+               errorf("gob: length mismatch in ignoreArray")
        }
-       return ignoreArrayHelper(state, elemOp, length)
+       ignoreArrayHelper(state, elemOp, length)
 }
 
-func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error {
+func ignoreMap(state *decodeState, keyOp, elemOp decOp) {
        n := int(decodeUint(state))
        keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")}
        elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
-       for i := 0; i < n && state.err == nil; i++ {
+       for i := 0; i < n; i++ {
                keyOp(keyInstr, state, nil)
                elemOp(elemInstr, state, nil)
        }
-       return state.err
 }
 
-func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error {
+func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) {
        n := int(uintptr(decodeUint(state)))
        if indir > 0 {
                up := unsafe.Pointer(p)
@@ -583,11 +569,11 @@ func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp
        hdrp.Data = uintptr(unsafe.NewArray(atyp.Elem(), n))
        hdrp.Len = n
        hdrp.Cap = n
-       return decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl)
+       decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl)
 }
 
-func ignoreSlice(state *decodeState, elemOp decOp) os.Error {
-       return ignoreArrayHelper(state, elemOp, int(decodeUint(state)))
+func ignoreSlice(state *decodeState, elemOp decOp) {
+       ignoreArrayHelper(state, elemOp, int(decodeUint(state)))
 }
 
 // setInterfaceValue sets an interface value to a concrete value through
@@ -596,19 +582,18 @@ func ignoreSlice(state *decodeState, elemOp decOp) os.Error {
 // This dance avoids manually checking that the value satisfies the
 // interface.
 // TODO(rsc): avoid panic+recover after fixing issue 327.
-func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) (err os.Error) {
+func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) {
        defer func() {
                if e := recover(); e != nil {
-                       err = e.(os.Error)
+                       error(e.(os.Error))
                }
        }()
        ivalue.Set(value)
-       return nil
 }
 
 // decodeInterface receives the name of a concrete type followed by its value.
 // If the name is empty, the value is nil and no value is sent.
-func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) os.Error {
+func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) {
        // Create an interface reflect.Value.  We need one even for the nil case.
        ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue)
        // Read the name of the concrete type.
@@ -619,20 +604,18 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt
                // Copy the representation of the nil interface value to the target.
                // This is horribly unsafe and special.
                *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get()
-               return state.err
+               return
        }
        // The concrete type must be registered.
        typ, ok := nameToConcreteType[name]
        if !ok {
-               state.err = os.ErrorString("gob: name not registered for interface: " + name)
-               return state.err
+               errorf("gob: name not registered for interface: %q", name)
        }
        // Read the concrete value.
        value := reflect.MakeZero(typ)
        dec.decodeValueFromBuffer(value, false)
-       if dec.state.err != nil {
-               state.err = dec.state.err
-               return state.err
+       if dec.err != nil {
+               error(dec.err)
        }
        // Allocate the destination interface value.
        if indir > 0 {
@@ -640,26 +623,22 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt
        }
        // Assign the concrete value to the interface.
        // Tread carefully; it might not satisfy the interface.
-       dec.state.err = setInterfaceValue(ivalue, value)
-       if dec.state.err != nil {
-               state.err = dec.state.err
-               return state.err
-       }
+       setInterfaceValue(ivalue, value)
        // Copy the representation of the interface value to the target.
        // This is horribly unsafe and special.
        *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get()
-       return nil
 }
 
-func (dec *Decoder) ignoreInterface(state *decodeState) os.Error {
+func (dec *Decoder) ignoreInterface(state *decodeState) {
        // Read the name of the concrete type.
        b := make([]byte, decodeUint(state))
        _, err := state.b.Read(b)
        if err != nil {
                dec.decodeValueFromBuffer(nil, true)
-               err = dec.state.err
+               if dec.err != nil {
+                       error(err)
+               }
        }
-       return err
 }
 
 // Index by Go types.
@@ -712,7 +691,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
                        }
                        ovfl := overflow(name)
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
+                               decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
                        }
 
                case *reflect.MapType:
@@ -730,7 +709,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
                        ovfl := overflow(name)
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
                                up := unsafe.Pointer(p)
-                               state.err = decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
+                               decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
                        }
 
                case *reflect.SliceType:
@@ -751,7 +730,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
                        }
                        ovfl := overflow(name)
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
+                               decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
                        }
 
                case *reflect.StructType:
@@ -762,11 +741,14 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
                                // indirect through enginePtr to delay evaluation for recursive structs
-                               state.err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir)
+                               err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir)
+                               if err != nil {
+                                       error(err)
+                               }
                        }
                case *reflect.InterfaceType:
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = dec.decodeInterface(t, state, uintptr(p), i.indir)
+                               dec.decodeInterface(t, state, uintptr(p), i.indir)
                        }
                }
        }
@@ -784,7 +766,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
                        // Special case because it's a method: the ignored item might
                        // define types and we need to record their state in the decoder.
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = dec.ignoreInterface(state)
+                               dec.ignoreInterface(state)
                        }
                        return op, nil
                }
@@ -800,7 +782,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
                                return nil, err
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = ignoreArray(state, elemOp, wire.arrayT.Len)
+                               ignoreArray(state, elemOp, wire.arrayT.Len)
                        }
 
                case wire.mapT != nil:
@@ -815,7 +797,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
                                return nil, err
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = ignoreMap(state, keyOp, elemOp)
+                               ignoreMap(state, keyOp, elemOp)
                        }
 
                case wire.sliceT != nil:
@@ -825,7 +807,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
                                return nil, err
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
-                               state.err = ignoreSlice(state, elemOp)
+                               ignoreSlice(state, elemOp)
                        }
 
                case wire.structT != nil:
@@ -836,7 +818,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
                        }
                        op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
                                // indirect through enginePtr to delay evaluation for recursive structs
-                               state.err = ignoreStruct(*enginePtr, state.b)
+                               ignoreStruct(*enginePtr, state.b)
                        }
                }
        }
index e2f1e363f697301ff1e63448e7afbe6e3d5d3caa..b86bdf3985e3631a50c540ec1a13aa109ac864f4 100644 (file)
@@ -25,6 +25,7 @@ type Decoder struct {
        buf          []byte
        countBuf     [9]byte // counts may be uint64s (unlikely!), require 9 bytes
        byteBuffer   *bytes.Buffer
+       err          os.Error
 }
 
 // NewDecoder returns a new decoder that reads from the io.Reader.
@@ -43,13 +44,16 @@ func NewDecoder(r io.Reader) *Decoder {
 func (dec *Decoder) recvType(id typeId) {
        // Have we already seen this type?  That's an error
        if dec.wireType[id] != nil {
-               dec.state.err = os.ErrorString("gob: duplicate type received")
+               dec.err = os.ErrorString("gob: duplicate type received")
                return
        }
 
        // Type:
        wire := new(wireType)
-       dec.state.err = dec.decode(tWireType, reflect.NewValue(wire))
+       dec.err = dec.decode(tWireType, reflect.NewValue(wire))
+       if dec.err != nil {
+               return
+       }
        // Remember we've seen this type.
        dec.wireType[id] = wire
 
@@ -66,8 +70,8 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
        // If e represents a value as opposed to a pointer, the answer won't
        // get back to the caller.  Make sure it's a pointer.
        if value.Type().Kind() != reflect.Ptr {
-               dec.state.err = os.ErrorString("gob: attempt to decode into a non-pointer")
-               return dec.state.err
+               dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
+               return dec.err
        }
        return dec.DecodeValue(value)
 }
@@ -77,8 +81,8 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
 func (dec *Decoder) recv() {
        // Read a count.
        var nbytes uint64
-       nbytes, dec.state.err = decodeUintReader(dec.r, dec.countBuf[0:])
-       if dec.state.err != nil {
+       nbytes, dec.err = decodeUintReader(dec.r, dec.countBuf[0:])
+       if dec.err != nil {
                return
        }
        // Allocate the buffer.
@@ -88,10 +92,10 @@ func (dec *Decoder) recv() {
        dec.byteBuffer = bytes.NewBuffer(dec.buf[0:nbytes])
 
        // Read the data
-       _, dec.state.err = io.ReadFull(dec.r, dec.buf[0:nbytes])
-       if dec.state.err != nil {
-               if dec.state.err == os.EOF {
-                       dec.state.err = io.ErrUnexpectedEOF
+       _, dec.err = io.ReadFull(dec.r, dec.buf[0:nbytes])
+       if dec.err != nil {
+               if dec.err == os.EOF {
+                       dec.err = io.ErrUnexpectedEOF
                }
                return
        }
@@ -104,7 +108,7 @@ func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignore bool) {
        for dec.state.b.Len() > 0 {
                // Receive a type id.
                id := typeId(decodeInt(dec.state))
-               if dec.state.err != nil {
+               if dec.err != nil {
                        break
                }
 
@@ -112,7 +116,7 @@ func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignore bool) {
                if id < 0 { // 0 is the error state, handled above
                        // If the id is negative, we have a type.
                        dec.recvType(-id)
-                       if dec.state.err != nil {
+                       if dec.err != nil {
                                break
                        }
                        continue
@@ -126,10 +130,10 @@ func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignore bool) {
                // Make sure the type has been defined already or is a builtin type (for
                // top-level singleton values).
                if dec.wireType[id] == nil && builtinIdToType[id] == nil {
-                       dec.state.err = errBadType
+                       dec.err = errBadType
                        break
                }
-               dec.state.err = dec.decode(id, value)
+               dec.err = dec.decode(id, value)
                break
        }
 }
@@ -143,11 +147,11 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
        dec.mutex.Lock()
        defer dec.mutex.Unlock()
 
-       dec.state.err = nil
+       dec.err = nil
        dec.recv()
-       if dec.state.err != nil {
-               return dec.state.err
+       if dec.err != nil {
+               return dec.err
        }
        dec.decodeValueFromBuffer(value, false)
-       return dec.state.err
+       return dec.err
 }
index 833b87c76728d4cc52190827b0ddc6e002345e3b..0be2d81a5a34638e934d8bd3b3aa02f13e9f27d4 100644 (file)
@@ -21,7 +21,6 @@ const uint64Size = unsafe.Sizeof(uint64(0))
 // 0 terminates the structure.
 type encoderState struct {
        b        *bytes.Buffer
-       err      os.Error             // error encountered during encoding.
        sendZero bool                 // encoding an array element or map key/value pair; send zero values
        fieldnum int                  // the last field number written.
        buf      [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
@@ -39,11 +38,11 @@ func newEncoderState(b *bytes.Buffer) *encoderState {
 // encodeUint writes an encoded unsigned integer to state.b.  Sets state.err.
 // If state.err is already non-nil, it does nothing.
 func encodeUint(state *encoderState, x uint64) {
-       if state.err != nil {
-               return
-       }
        if x <= 0x7F {
-               state.err = state.b.WriteByte(uint8(x))
+               err := state.b.WriteByte(uint8(x))
+               if err != nil {
+                       error(err)
+               }
                return
        }
        var n, m int
@@ -54,7 +53,10 @@ func encodeUint(state *encoderState, x uint64) {
                m--
        }
        state.buf[m] = uint8(-(n - 1))
-       n, state.err = state.b.Write(state.buf[m : uint64Size+1])
+       n, err := state.b.Write(state.buf[m : uint64Size+1])
+       if err != nil {
+               error(err)
+       }
 }
 
 // encodeInt writes an encoded signed integer to state.w.
@@ -317,7 +319,8 @@ type encEngine struct {
 
 const singletonField = 0
 
-func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
+func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Error) {
+       defer catchError(&err)
        state := newEncoderState(b)
        state.fieldnum = singletonField
        // There is no surrounding struct to frame the transmission, so we must
@@ -331,10 +334,11 @@ func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
                }
        }
        instr.op(instr, state, p)
-       return state.err
+       return
 }
 
-func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
+func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Error) {
+       defer catchError(&err)
        state := newEncoderState(b)
        state.fieldnum = -1
        for i := 0; i < len(engine.instr); i++ {
@@ -346,32 +350,27 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error {
                        }
                }
                instr.op(instr, state, p)
-               if state.err != nil {
-                       break
-               }
        }
-       return state.err
+       return nil
 }
 
-func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) os.Error {
+func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) {
        state := newEncoderState(b)
        state.fieldnum = -1
        state.sendZero = true
        encodeUint(state, uint64(length))
-       for i := 0; i < length && state.err == nil; i++ {
+       for i := 0; i < length; i++ {
                elemp := p
                up := unsafe.Pointer(elemp)
                if elemIndir > 0 {
                        if up = encIndirect(up, elemIndir); up == nil {
-                               state.err = os.ErrorString("gob: encodeArray: nil element")
-                               break
+                               errorf("gob: encodeArray: nil element")
                        }
                        elemp = uintptr(up)
                }
                op(nil, state, unsafe.Pointer(elemp))
                p += uintptr(elemWid)
        }
-       return state.err
 }
 
 func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
@@ -379,61 +378,54 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in
                v = reflect.Indirect(v)
        }
        if v == nil {
-               state.err = os.ErrorString("gob: encodeReflectValue: nil element")
-               return
+               errorf("gob: encodeReflectValue: nil element")
        }
        op(nil, state, unsafe.Pointer(v.Addr()))
 }
 
-func encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error {
+func encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) {
        state := newEncoderState(b)
        state.fieldnum = -1
        state.sendZero = true
        keys := mv.Keys()
        encodeUint(state, uint64(len(keys)))
        for _, key := range keys {
-               if state.err != nil {
-                       break
-               }
                encodeReflectValue(state, key, keyOp, keyIndir)
                encodeReflectValue(state, mv.Elem(key), elemOp, elemIndir)
        }
-       return state.err
 }
 
 // To send an interface, we send a string identifying the concrete type, followed
 // by the type identifier (which might require defining that type right now), followed
 // by the concrete value.  A nil value gets sent as the empty string for the name,
 // followed by no value.
-func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) os.Error {
+func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) {
        state := newEncoderState(b)
        state.fieldnum = -1
        state.sendZero = true
        if iv.IsNil() {
                encodeUint(state, 0)
-               return state.err
+               return
        }
 
        typ := iv.Elem().Type()
        name, ok := concreteTypeToName[typ]
        if !ok {
-               state.err = os.ErrorString("gob: type not registered for interface: " + typ.String())
-               return state.err
+               errorf("gob: type not registered for interface: %s", typ)
        }
        // Send the name.
        encodeUint(state, uint64(len(name)))
-       _, state.err = io.WriteString(state.b, name)
-       if state.err != nil {
-               return state.err
+       _, err := io.WriteString(state.b, name)
+       if err != nil {
+               error(err)
        }
        // Send (and maybe first define) the type id.
        enc.sendTypeDescriptor(typ)
-       if state.err != nil {
-               return state.err
-       }
        // Send the value.
-       state.err = enc.encode(state.b, iv.Elem())
-       return state.err
+       err = enc.encode(state.b, iv.Elem())
+       if err != nil {
+               error(err)
+       }
 }
 
 var encOpMap = []encOp{
@@ -486,7 +478,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                                        return
                                }
                                state.update(i)
-                               state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len))
+                               encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len))
                        }
                case *reflect.ArrayType:
                        // True arrays have size in the type.
@@ -496,7 +488,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                        }
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
                                state.update(i)
-                               state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
+                               encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
                        }
                case *reflect.MapType:
                        keyOp, keyIndir, err := enc.encOpFor(t.Key())
@@ -517,7 +509,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                                        return
                                }
                                state.update(i)
-                               state.err = encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir)
+                               encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir)
                        }
                case *reflect.StructType:
                        // Generate a closure that calls out to the engine for the nested type.
@@ -529,7 +521,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
                                state.update(i)
                                // indirect through info to delay evaluation for recursive structs
-                               state.err = encodeStruct(info.encoder, state.b, uintptr(p))
+                               encodeStruct(info.encoder, state.b, uintptr(p))
                        }
                case *reflect.InterfaceType:
                        op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
@@ -541,7 +533,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                                        return
                                }
                                state.update(i)
-                               state.err = enc.encodeInterface(state.b, iv)
+                               enc.encodeInterface(state.b, iv)
                        }
                }
        }
index ff9834600f30f3915662c6c68183c50ff11b5135..5d12d920b4164aca40a198be6ef9f5db0b74cd35 100644 (file)
@@ -21,6 +21,7 @@ type Encoder struct {
        state      *encoderState           // so we can encode integers, strings directly
        countState *encoderState           // stage for writing counts
        buf        []byte                  // for collecting the output.
+       err        os.Error
 }
 
 // NewEncoder returns a new encoder that will transmit on the io.Writer.
@@ -38,8 +39,8 @@ func (enc *Encoder) badType(rt reflect.Type) {
 }
 
 func (enc *Encoder) setError(err os.Error) {
-       if enc.state.err == nil { // remember the first.
-               enc.state.err = err
+       if enc.err == nil { // remember the first.
+               enc.err = err
        }
        enc.state.b.Reset()
 }
@@ -115,7 +116,7 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
        // Type:
        enc.encode(enc.state.b, reflect.NewValue(info.wire))
        enc.send()
-       if enc.state.err != nil {
+       if enc.err != nil {
                return
        }
 
@@ -150,7 +151,7 @@ func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) {
        if _, alreadySent := enc.sent[rt]; !alreadySent {
                // No, so send it.
                sent := enc.sendType(rt)
-               if enc.state.err != nil {
+               if enc.err != nil {
                        return
                }
                // If the type info has still not been transmitted, it means we have
@@ -180,18 +181,18 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
        enc.mutex.Lock()
        defer enc.mutex.Unlock()
 
-       enc.state.err = nil
+       enc.err = nil
        rt, _ := indirect(value.Type())
 
        // Sanity check only: encoder should never come in with data present.
        if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 {
-               enc.state.err = os.ErrorString("encoder: buffer not empty")
-               return enc.state.err
+               enc.err = os.ErrorString("encoder: buffer not empty")
+               return enc.err
        }
 
        enc.sendTypeDescriptor(rt)
-       if enc.state.err != nil {
-               return enc.state.err
+       if enc.err != nil {
+               return enc.err
        }
 
        // Encode the object.
@@ -202,5 +203,5 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
                enc.send()
        }
 
-       return enc.state.err
+       return enc.err
 }
index 4f2702a4dd5cafbef11429c5d01c16baaf0203cc..91d85bb7ad65b56368669acadb69034a154872b8 100644 (file)
@@ -43,15 +43,15 @@ func TestEncoderDecoder(t *testing.T) {
        et1 := new(ET1)
        et1.a = 7
        et1.et2 = new(ET2)
-       enc.Encode(et1)
-       if enc.state.err != nil {
-               t.Error("encoder fail:", enc.state.err)
+       err := enc.Encode(et1)
+       if err != nil {
+               t.Error("encoder fail:", err)
        }
        dec := NewDecoder(b)
        newEt1 := new(ET1)
-       dec.Decode(newEt1)
-       if dec.state.err != nil {
-               t.Fatal("error decoding ET1:", dec.state.err)
+       err = dec.Decode(newEt1)
+       if err != nil {
+               t.Fatal("error decoding ET1:", err)
        }
 
        if !reflect.DeepEqual(et1, newEt1) {
@@ -63,9 +63,9 @@ func TestEncoderDecoder(t *testing.T) {
 
        enc.Encode(et1)
        newEt1 = new(ET1)
-       dec.Decode(newEt1)
-       if dec.state.err != nil {
-               t.Fatal("round 2: error decoding ET1:", dec.state.err)
+       err = dec.Decode(newEt1)
+       if err != nil {
+               t.Fatal("round 2: error decoding ET1:", err)
        }
        if !reflect.DeepEqual(et1, newEt1) {
                t.Fatalf("round 2: invalid data for et1: expected %+v; got %+v", *et1, *newEt1)
@@ -75,13 +75,13 @@ func TestEncoderDecoder(t *testing.T) {
        }
 
        // Now test with a running encoder/decoder pair that we recognize a type mismatch.
-       enc.Encode(et1)
-       if enc.state.err != nil {
-               t.Error("round 3: encoder fail:", enc.state.err)
+       err = enc.Encode(et1)
+       if err != nil {
+               t.Error("round 3: encoder fail:", err)
        }
        newEt2 := new(ET2)
-       dec.Decode(newEt2)
-       if dec.state.err == nil {
+       err = dec.Decode(newEt2)
+       if err == nil {
                t.Fatal("round 3: expected `bad type' error decoding ET2")
        }
 }
@@ -94,17 +94,17 @@ func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) {
        et1 := new(ET1)
        et1.a = 7
        et1.et2 = new(ET2)
-       enc.Encode(et1)
-       if enc.state.err != nil {
-               t.Error("encoder fail:", enc.state.err)
+       err := enc.Encode(et1)
+       if err != nil {
+               t.Error("encoder fail:", err)
        }
        dec := NewDecoder(b)
-       dec.Decode(e)
-       if shouldFail && (dec.state.err == nil) {
+       err = dec.Decode(e)
+       if shouldFail && err == nil {
                t.Error("expected error for", msg)
        }
-       if !shouldFail && (dec.state.err != nil) {
-               t.Error("unexpected error for", msg, dec.state.err)
+       if !shouldFail && err != nil {
+               t.Error("unexpected error for", msg, err)
        }
 }
 
@@ -118,9 +118,9 @@ func TestWrongTypeDecoder(t *testing.T) {
 func corruptDataCheck(s string, err os.Error, t *testing.T) {
        b := bytes.NewBufferString(s)
        dec := NewDecoder(b)
-       dec.Decode(new(ET2))
-       if dec.state.err != err {
-               t.Error("expected error", err, "got", dec.state.err)
+       err1 := dec.Decode(new(ET2))
+       if err1 != err {
+               t.Error("expected error", err, "got", err1)
        }
 }
 
@@ -151,14 +151,14 @@ func TestUnsupported(t *testing.T) {
 func encAndDec(in, out interface{}) os.Error {
        b := new(bytes.Buffer)
        enc := NewEncoder(b)
-       enc.Encode(in)
-       if enc.state.err != nil {
-               return enc.state.err
+       err := enc.Encode(in)
+       if err != nil {
+               return err
        }
        dec := NewDecoder(b)
-       dec.Decode(out)
-       if dec.state.err != nil {
-               return dec.state.err
+       err = dec.Decode(out)
+       if err != nil {
+               return err
        }
        return nil
 }
diff --git a/src/pkg/gob/error.go b/src/pkg/gob/error.go
new file mode 100644 (file)
index 0000000..b053761
--- /dev/null
@@ -0,0 +1,41 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+       "fmt"
+       "os"
+)
+
+// Errors in decoding and encoding are handled using panic and recover.
+// Panics caused by user error (that is, everything except run-time panics
+// such as "index out of bounds" errors) do not leave the file that caused
+// them, but are instead turned into plain os.Error returns.  Encoding and
+// decoding functions and methods that do not return an os.Error either use
+// panic to report an error or are guaranteed error-free.
+
+// A gobError wraps an os.Error and is used to distinguish errors (panics) generated in this package.
+type gobError struct {
+       os.Error
+}
+
+// errorf is like error but takes Printf-style arguments to construct an os.Error.
+func errorf(format string, args ...interface{}) {
+       error(fmt.Errorf(format, args...))
+}
+
+// error wraps the argument error and uses it as the argument to panic.
+func error(err os.Error) {
+       panic(gobError{Error: err})
+}
+
+// catchError is meant to be used as a deferred function to turn a panic(gobError) into a
+// plain os.Error.  It overwrites the error return of the function that deferred its call.
+func catchError(err *os.Error) {
+       if e := recover(); e != nil {
+               *err = e.(gobError).Error // Will re-panic if not one of our errors, such as a runtime error.
+       }
+       return
+}