From e76a335adacb2f9dd3a930a53eb9ab4a7c766209 Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Wed, 15 Jul 2009 16:10:17 -0700 Subject: [PATCH] make the low-level encoder and decoder private and have them access byte.Buffers rather than io.Readers and io.Writers. change the Encoder/Decoder protocol so that each message is preceded by its length in bytes. R=rsc DELTA=468 (119 added, 23 deleted, 326 changed) OCL=31700 CL=31702 --- src/pkg/Make.deps | 2 +- src/pkg/gob/codec_test.go | 142 ++++++++++++++++----------------- src/pkg/gob/decode.go | 152 ++++++++++++++++++++---------------- src/pkg/gob/decoder.go | 81 ++++++++++++------- src/pkg/gob/encode.go | 136 ++++++++++++++++---------------- src/pkg/gob/encoder.go | 53 ++++++++++--- src/pkg/gob/encoder_test.go | 50 ++++++++---- src/pkg/gob/type.go | 10 +-- 8 files changed, 361 insertions(+), 265 deletions(-) diff --git a/src/pkg/Make.deps b/src/pkg/Make.deps index 7bbb22eb66..532ede1468 100644 --- a/src/pkg/Make.deps +++ b/src/pkg/Make.deps @@ -24,7 +24,7 @@ go/parser.install: bytes.install container/vector.install fmt.install go/ast.ins go/printer.install: fmt.install go/ast.install go/token.install io.install os.install reflect.install strings.install go/scanner.install: bytes.install container/vector.install fmt.install go/token.install io.install os.install sort.install strconv.install unicode.install utf8.install go/token.install: strconv.install -gob.install: fmt.install io.install math.install os.install reflect.install strings.install sync.install unicode.install +gob.install: bytes.install fmt.install io.install math.install os.install reflect.install strings.install sync.install unicode.install hash.install: io.install hash/adler32.install: hash.install os.install hash/crc32.install: hash.install os.install diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index 23ff885f0e..44ebf3a113 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -39,27 +39,27 @@ var encodeT = []EncodeT { // Test basic encode/decode routines for unsigned integers func TestUintCodec(t *testing.T) { b := new(bytes.Buffer); - encState := new(EncState); - encState.w = b; + encState := new(encoderState); + encState.b = b; for i, tt := range encodeT { b.Reset(); - EncodeUint(encState, tt.x); + encodeUint(encState, tt.x); if encState.err != nil { - t.Error("EncodeUint:", tt.x, encState.err) + t.Error("encodeUint:", tt.x, encState.err) } if !bytes.Equal(tt.b, b.Data()) { - t.Errorf("EncodeUint: expected % x got % x", tt.b, b.Data()) + t.Errorf("encodeUint: expected % x got % x", tt.b, b.Data()) } } - decState := new(DecState); - decState.r = b; + decState := new(decodeState); + decState.b = b; for u := uint64(0); ; u = (u+1) * 7 { b.Reset(); - EncodeUint(encState, u); + encodeUint(encState, u); if encState.err != nil { - t.Error("EncodeUint:", u, encState.err) + t.Error("encodeUint:", u, encState.err) } - v := DecodeUint(decState); + v := decodeUint(decState); if decState.err != nil { t.Error("DecodeUint:", u, decState.err) } @@ -74,15 +74,15 @@ func TestUintCodec(t *testing.T) { func verifyInt(i int64, t *testing.T) { var b = new(bytes.Buffer); - encState := new(EncState); - encState.w = b; - EncodeInt(encState, i); + encState := new(encoderState); + encState.b = b; + encodeInt(encState, i); if encState.err != nil { - t.Error("EncodeInt:", i, encState.err) + t.Error("encodeInt:", i, encState.err) } - decState := new(DecState); - decState.r = b; - j := DecodeInt(decState); + decState := new(decodeState); + decState.b = b; + j := decodeInt(decState); if decState.err != nil { t.Error("DecodeInt:", i, decState.err) } @@ -116,10 +116,10 @@ var floatResult = []byte{0x87, 0x40, 0xe2} // The result of encoding "hello" with field number 6 var bytesResult = []byte{0x87, 0x85, 'h', 'e', 'l', 'l', 'o'} -func newEncState(b *bytes.Buffer) *EncState { +func newencoderState(b *bytes.Buffer) *encoderState { b.Reset(); - state := new(EncState); - state.w = b; + state := new(encoderState); + state.b = b; state.fieldnum = -1; return state; } @@ -133,7 +133,7 @@ func TestScalarEncInstructions(t *testing.T) { { data := struct { a bool } { true }; instr := &encInstr{ encBool, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(boolResult, b.Data()) { t.Errorf("bool enc instructions: expected % x got % x", boolResult, b.Data()) @@ -145,7 +145,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a int } { 17 }; instr := &encInstr{ encInt, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(signedResult, b.Data()) { t.Errorf("int enc instructions: expected % x got % x", signedResult, b.Data()) @@ -157,7 +157,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a uint } { 17 }; instr := &encInstr{ encUint, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(unsignedResult, b.Data()) { t.Errorf("uint enc instructions: expected % x got % x", unsignedResult, b.Data()) @@ -169,7 +169,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a int8 } { 17 }; instr := &encInstr{ encInt, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(signedResult, b.Data()) { t.Errorf("int8 enc instructions: expected % x got % x", signedResult, b.Data()) @@ -181,7 +181,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a uint8 } { 17 }; instr := &encInstr{ encUint, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(unsignedResult, b.Data()) { t.Errorf("uint8 enc instructions: expected % x got % x", unsignedResult, b.Data()) @@ -196,7 +196,7 @@ func TestScalarEncInstructions(t *testing.T) { ppv := &pv; data := struct { a int16 } { 17 }; instr := &encInstr{ encInt16, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(signedResult, b.Data()) { t.Errorf("int16 enc instructions: expected % x got % x", signedResult, b.Data()) @@ -208,7 +208,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a uint16 } { 17 }; instr := &encInstr{ encUint16, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(unsignedResult, b.Data()) { t.Errorf("uint16 enc instructions: expected % x got % x", unsignedResult, b.Data()) @@ -220,7 +220,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a int32 } { 17 }; instr := &encInstr{ encInt32, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(signedResult, b.Data()) { t.Errorf("int32 enc instructions: expected % x got % x", signedResult, b.Data()) @@ -232,7 +232,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a uint32 } { 17 }; instr := &encInstr{ encUint32, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(unsignedResult, b.Data()) { t.Errorf("uint32 enc instructions: expected % x got % x", unsignedResult, b.Data()) @@ -244,7 +244,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a int64 } { 17 }; instr := &encInstr{ encInt64, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(signedResult, b.Data()) { t.Errorf("int64 enc instructions: expected % x got % x", signedResult, b.Data()) @@ -256,7 +256,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a uint64 } { 17 }; instr := &encInstr{ encUint, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(unsignedResult, b.Data()) { t.Errorf("uint64 enc instructions: expected % x got % x", unsignedResult, b.Data()) @@ -268,7 +268,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a float } { 17 }; instr := &encInstr{ encFloat, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(floatResult, b.Data()) { t.Errorf("float enc instructions: expected % x got % x", floatResult, b.Data()) @@ -280,7 +280,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a float32 } { 17 }; instr := &encInstr{ encFloat32, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(floatResult, b.Data()) { t.Errorf("float32 enc instructions: expected % x got % x", floatResult, b.Data()) @@ -292,7 +292,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a float64 } { 17 }; instr := &encInstr{ encFloat64, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(floatResult, b.Data()) { t.Errorf("float64 enc instructions: expected % x got % x", floatResult, b.Data()) @@ -304,7 +304,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a []byte } { strings.Bytes("hello") }; instr := &encInstr{ encUint8Array, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(bytesResult, b.Data()) { t.Errorf("bytes enc instructions: expected % x got % x", bytesResult, b.Data()) @@ -316,7 +316,7 @@ func TestScalarEncInstructions(t *testing.T) { b.Reset(); data := struct { a string } { "hello" }; instr := &encInstr{ encString, 6, 0, 0 }; - state := newEncState(b); + state := newencoderState(b); instr.op(instr, state, unsafe.Pointer(&data)); if !bytes.Equal(bytesResult, b.Data()) { t.Errorf("string enc instructions: expected % x got % x", bytesResult, b.Data()) @@ -324,8 +324,8 @@ func TestScalarEncInstructions(t *testing.T) { } } -func execDec(typ string, instr *decInstr, state *DecState, t *testing.T, p unsafe.Pointer) { - v := int(DecodeUint(state)); +func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) { + v := int(decodeUint(state)); if state.err != nil { t.Fatalf("decoding %s field: %v", typ, state.err); } @@ -336,9 +336,9 @@ func execDec(typ string, instr *decInstr, state *DecState, t *testing.T, p unsaf state.fieldnum = 6; } -func newDecState(data []byte) *DecState { - state := new(DecState); - state.r = bytes.NewBuffer(data); +func newdecodeState(data []byte) *decodeState { + state := new(decodeState); + state.b = bytes.NewBuffer(data); state.fieldnum = -1; return state; } @@ -351,7 +351,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a bool }; instr := &decInstr{ decBool, 6, 0, 0 }; - state := newDecState(boolResult); + state := newdecodeState(boolResult); execDec("bool", instr, state, t, unsafe.Pointer(&data)); if data.a != true { t.Errorf("int a = %v not true", data.a) @@ -361,7 +361,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a int }; instr := &decInstr{ decInt, 6, 0, 0 }; - state := newDecState(signedResult); + state := newdecodeState(signedResult); execDec("int", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -372,7 +372,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a uint }; instr := &decInstr{ decUint, 6, 0, 0 }; - state := newDecState(unsignedResult); + state := newdecodeState(unsignedResult); execDec("uint", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -383,7 +383,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a int8 }; instr := &decInstr{ decInt8, 6, 0, 0 }; - state := newDecState(signedResult); + state := newdecodeState(signedResult); execDec("int8", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -394,7 +394,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a uint8 }; instr := &decInstr{ decUint8, 6, 0, 0 }; - state := newDecState(unsignedResult); + state := newdecodeState(unsignedResult); execDec("uint8", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -405,7 +405,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a int16 }; instr := &decInstr{ decInt16, 6, 0, 0 }; - state := newDecState(signedResult); + state := newdecodeState(signedResult); execDec("int16", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -416,7 +416,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a uint16 }; instr := &decInstr{ decUint16, 6, 0, 0 }; - state := newDecState(unsignedResult); + state := newdecodeState(unsignedResult); execDec("uint16", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -427,7 +427,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a int32 }; instr := &decInstr{ decInt32, 6, 0, 0 }; - state := newDecState(signedResult); + state := newdecodeState(signedResult); execDec("int32", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -438,7 +438,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a uint32 }; instr := &decInstr{ decUint32, 6, 0, 0 }; - state := newDecState(unsignedResult); + state := newdecodeState(unsignedResult); execDec("uint32", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -449,7 +449,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a int64 }; instr := &decInstr{ decInt64, 6, 0, 0 }; - state := newDecState(signedResult); + state := newdecodeState(signedResult); execDec("int64", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -460,7 +460,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a uint64 }; instr := &decInstr{ decUint64, 6, 0, 0 }; - state := newDecState(unsignedResult); + state := newdecodeState(unsignedResult); execDec("uint64", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -471,7 +471,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a float }; instr := &decInstr{ decFloat, 6, 0, 0 }; - state := newDecState(floatResult); + state := newdecodeState(floatResult); execDec("float", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -482,7 +482,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a float32 }; instr := &decInstr{ decFloat32, 6, 0, 0 }; - state := newDecState(floatResult); + state := newdecodeState(floatResult); execDec("float32", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -493,7 +493,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a float64 }; instr := &decInstr{ decFloat64, 6, 0, 0 }; - state := newDecState(floatResult); + state := newdecodeState(floatResult); execDec("float64", instr, state, t, unsafe.Pointer(&data)); if data.a != 17 { t.Errorf("int a = %v not 17", data.a) @@ -504,7 +504,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a []byte }; instr := &decInstr{ decUint8Array, 6, 0, 0 }; - state := newDecState(bytesResult); + state := newdecodeState(bytesResult); execDec("bytes", instr, state, t, unsafe.Pointer(&data)); if string(data.a) != "hello" { t.Errorf(`bytes a = %q not "hello"`, string(data.a)) @@ -515,7 +515,7 @@ func TestScalarDecInstructions(t *testing.T) { { var data struct { a string }; instr := &decInstr{ decString, 6, 0, 0 }; - state := newDecState(bytesResult); + state := newdecodeState(bytesResult); execDec("bytes", instr, state, t, unsafe.Pointer(&data)); if data.a != "hello" { t.Errorf(`bytes a = %q not "hello"`, data.a) @@ -551,9 +551,9 @@ func TestEndToEnd(t *testing.T) { t: &T2{"this is T2"}, }; b := new(bytes.Buffer); - Encode(b, t1); + encode(b, t1); var _t1 T1; - Decode(b, &_t1); + decode(b, &_t1); if !reflect.DeepEqual(t1, &_t1) { t.Errorf("encode expected %v got %v", *t1, _t1); } @@ -569,9 +569,9 @@ func TestNesting(t *testing.T) { rt.next = new(RT); rt.next.a = "level2"; b := new(bytes.Buffer); - Encode(b, rt); + encode(b, rt); var drt RT; - Decode(b, &drt); + decode(b, &drt); if drt.a != rt.a { t.Errorf("nesting: encode expected %v got %v", *rt, drt); } @@ -611,9 +611,9 @@ func TestAutoIndirection(t *testing.T) { t1.c = new(*int); *t1.c = new(int); **t1.c = 1777; t1.d = new(**int); *t1.d = new(*int); **t1.d = new(int); ***t1.d = 17777; b := new(bytes.Buffer); - Encode(b, t1); + encode(b, t1); var t0 T0; - Decode(b, &t0); + decode(b, &t0); if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 { t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0); } @@ -625,9 +625,9 @@ func TestAutoIndirection(t *testing.T) { t2.b = new(*int); *t2.b = new(int); **t2.b = 177; t2.a = new(**int); *t2.a = new(*int); **t2.a = new(int); ***t2.a = 17; b.Reset(); - Encode(b, t2); + encode(b, t2); t0 = T0{}; - Decode(b, &t0); + decode(b, &t0); if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 { t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0); } @@ -635,30 +635,30 @@ func TestAutoIndirection(t *testing.T) { // Now transfer t0 into t1 t0 = T0{17, 177, 1777, 17777}; b.Reset(); - Encode(b, t0); + encode(b, t0); t1 = T1{}; - Decode(b, &t1); + decode(b, &t1); if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 { t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d); } // Now transfer t0 into t2 b.Reset(); - Encode(b, t0); + encode(b, t0); t2 = T2{}; - Decode(b, &t2); + decode(b, &t2); if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 { t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d); } // Now do t2 again but without pre-allocated pointers. b.Reset(); - Encode(b, t0); + encode(b, t0); ***t2.a = 0; **t2.b = 0; *t2.c = 0; t2.d = 0; - Decode(b, &t2); + decode(b, &t2); if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 { t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d); } diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 4735f6ba1c..ec2ed66a5a 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -8,6 +8,7 @@ package gob // the allocations in this file that use unsafe.Pointer. import ( + "bytes"; "gob"; "io"; "math"; @@ -17,27 +18,44 @@ import ( ) // The global execution state of an instance of the decoder. -type DecState struct { - r io.Reader; +type decodeState struct { + b *bytes.Buffer; err os.Error; fieldnum int; // the last field number read. - buf [1]byte; // buffer used by the decoder; here to avoid allocation. } -// DecodeUint reads an encoded unsigned integer from state.r. +// decodeUintReader reads an encoded unsigned integer from an io.Reader. +// Used only by the Decoder to read the message length. +func decodeUintReader(r io.Reader, oneByte []byte) (x uint64, err os.Error) { + for shift := uint(0);; shift += 7 { + var n int; + n, err = r.Read(oneByte); + if err != nil { + return 0, err + } + b := oneByte[0]; + x |= uint64(b) << shift; + if b&0x80 != 0 { + x &^= 0x80 << shift; + break + } + } + return x, nil; +} + +// decodeUint reads an encoded unsigned integer from state.r. // Sets state.err. If state.err is already non-nil, it does nothing. -func DecodeUint(state *DecState) (x uint64) { +func decodeUint(state *decodeState) (x uint64) { if state.err != nil { return } for shift := uint(0);; shift += 7 { - var n int; - n, state.err = state.r.Read(&state.buf); - if n != 1 { + var b uint8; + b, state.err = state.b.ReadByte(); + if state.err != nil { return 0 } - b := uint64(state.buf[0]); - x |= b << shift; + x |= uint64(b) << shift; if b&0x80 != 0 { x &^= 0x80 << shift; break @@ -46,10 +64,10 @@ func DecodeUint(state *DecState) (x uint64) { return x; } -// DecodeInt reads an encoded signed integer from state.r. +// decodeInt reads an encoded signed integer from state.r. // Sets state.err. If state.err is already non-nil, it does nothing. -func DecodeInt(state *DecState) int64 { - x := DecodeUint(state); +func decodeInt(state *decodeState) int64 { + x := decodeUint(state); if state.err != nil { return 0 } @@ -60,7 +78,7 @@ func DecodeInt(state *DecState) int64 { } type decInstr struct -type decOp func(i *decInstr, state *DecState, p unsafe.Pointer); +type decOp func(i *decInstr, state *decodeState, p unsafe.Pointer); // The 'instructions' of the decoding machine type decInstr struct { @@ -89,124 +107,124 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { return p } -func decBool(i *decInstr, state *DecState, p unsafe.Pointer) { +func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(bool)); } p = *(*unsafe.Pointer)(p); } - *(*bool)(p) = DecodeInt(state) != 0; + *(*bool)(p) = decodeInt(state) != 0; } -func decInt(i *decInstr, state *DecState, p unsafe.Pointer) { +func decInt(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int)); } p = *(*unsafe.Pointer)(p); } - *(*int)(p) = int(DecodeInt(state)); + *(*int)(p) = int(decodeInt(state)); } -func decUint(i *decInstr, state *DecState, p unsafe.Pointer) { +func decUint(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint)); } p = *(*unsafe.Pointer)(p); } - *(*uint)(p) = uint(DecodeUint(state)); + *(*uint)(p) = uint(decodeUint(state)); } -func decInt8(i *decInstr, state *DecState, p unsafe.Pointer) { +func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int8)); } p = *(*unsafe.Pointer)(p); } - *(*int8)(p) = int8(DecodeInt(state)); + *(*int8)(p) = int8(decodeInt(state)); } -func decUint8(i *decInstr, state *DecState, p unsafe.Pointer) { +func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint8)); } p = *(*unsafe.Pointer)(p); } - *(*uint8)(p) = uint8(DecodeUint(state)); + *(*uint8)(p) = uint8(decodeUint(state)); } -func decInt16(i *decInstr, state *DecState, p unsafe.Pointer) { +func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int16)); } p = *(*unsafe.Pointer)(p); } - *(*int16)(p) = int16(DecodeInt(state)); + *(*int16)(p) = int16(decodeInt(state)); } -func decUint16(i *decInstr, state *DecState, p unsafe.Pointer) { +func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint16)); } p = *(*unsafe.Pointer)(p); } - *(*uint16)(p) = uint16(DecodeUint(state)); + *(*uint16)(p) = uint16(decodeUint(state)); } -func decInt32(i *decInstr, state *DecState, p unsafe.Pointer) { +func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int32)); } p = *(*unsafe.Pointer)(p); } - *(*int32)(p) = int32(DecodeInt(state)); + *(*int32)(p) = int32(decodeInt(state)); } -func decUint32(i *decInstr, state *DecState, p unsafe.Pointer) { +func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint32)); } p = *(*unsafe.Pointer)(p); } - *(*uint32)(p) = uint32(DecodeUint(state)); + *(*uint32)(p) = uint32(decodeUint(state)); } -func decInt64(i *decInstr, state *DecState, p unsafe.Pointer) { +func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int64)); } p = *(*unsafe.Pointer)(p); } - *(*int64)(p) = int64(DecodeInt(state)); + *(*int64)(p) = int64(decodeInt(state)); } -func decUint64(i *decInstr, state *DecState, p unsafe.Pointer) { +func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint64)); } p = *(*unsafe.Pointer)(p); } - *(*uint64)(p) = uint64(DecodeUint(state)); + *(*uint64)(p) = uint64(decodeUint(state)); } -func decUintptr(i *decInstr, state *DecState, p unsafe.Pointer) { +func decUintptr(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uintptr)); } p = *(*unsafe.Pointer)(p); } - *(*uintptr)(p) = uintptr(DecodeUint(state)); + *(*uintptr)(p) = uintptr(decodeUint(state)); } // Floating-point numbers are transmitted as uint64s holding the bits @@ -224,59 +242,59 @@ func floatFromBits(u uint64) float64 { return math.Float64frombits(v); } -func decFloat(i *decInstr, state *DecState, p unsafe.Pointer) { +func decFloat(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float)); } p = *(*unsafe.Pointer)(p); } - *(*float)(p) = float(floatFromBits(uint64(DecodeUint(state)))); + *(*float)(p) = float(floatFromBits(uint64(decodeUint(state)))); } -func decFloat32(i *decInstr, state *DecState, p unsafe.Pointer) { +func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32)); } p = *(*unsafe.Pointer)(p); } - *(*float32)(p) = float32(floatFromBits(uint64(DecodeUint(state)))); + *(*float32)(p) = float32(floatFromBits(uint64(decodeUint(state)))); } -func decFloat64(i *decInstr, state *DecState, p unsafe.Pointer) { +func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float64)); } p = *(*unsafe.Pointer)(p); } - *(*float64)(p) = floatFromBits(uint64(DecodeUint(state))); + *(*float64)(p) = floatFromBits(uint64(decodeUint(state))); } // uint8 arrays are encoded as an unsigned count followed by the raw bytes. -func decUint8Array(i *decInstr, state *DecState, p unsafe.Pointer) { +func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]uint8)); } p = *(*unsafe.Pointer)(p); } - b := make([]uint8, DecodeUint(state)); - state.r.Read(b); + b := make([]uint8, decodeUint(state)); + state.b.Read(b); *(*[]uint8)(p) = b; } // Strings are encoded as an unsigned count followed by the raw bytes. -func decString(i *decInstr, state *DecState, p unsafe.Pointer) { +func decString(i *decInstr, state *decodeState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte)); } p = *(*unsafe.Pointer)(p); } - b := make([]byte, DecodeUint(state)); - state.r.Read(b); + b := make([]byte, decodeUint(state)); + state.b.Read(b); *(*string)(p) = string(b); } @@ -288,7 +306,7 @@ type decEngine struct { instr []decInstr } -func decodeStruct(engine *decEngine, rtyp *reflect.StructType, r io.Reader, p uintptr, indir int) os.Error { +func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error { if indir > 0 { up := unsafe.Pointer(p); if *(*unsafe.Pointer)(up) == nil { @@ -299,12 +317,12 @@ func decodeStruct(engine *decEngine, rtyp *reflect.StructType, r io.Reader, p ui } p = *(*uintptr)(up); } - state := new(DecState); - state.r = r; + state := new(decodeState); + state.b = b; state.fieldnum = -1; basep := p; for state.err == nil { - delta := int(DecodeUint(state)); + delta := int(decodeUint(state)); if delta < 0 { state.err = os.ErrorString("gob decode: corrupted data: negative delta"); break @@ -327,7 +345,7 @@ func decodeStruct(engine *decEngine, rtyp *reflect.StructType, r io.Reader, p ui return state.err } -func decodeArrayHelper(state *DecState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int) os.Error { +func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int) os.Error { instr := &decInstr{elemOp, 0, elemIndir, 0}; for i := 0; i < length && state.err == nil; i++ { up := unsafe.Pointer(p); @@ -340,7 +358,7 @@ func decodeArrayHelper(state *DecState, p uintptr, elemOp decOp, elemWid uintptr return state.err } -func decodeArray(atyp *reflect.ArrayType, state *DecState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int) os.Error { +func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int) os.Error { if indir > 0 { up := unsafe.Pointer(p); if *(*unsafe.Pointer)(up) == nil { @@ -351,14 +369,14 @@ func decodeArray(atyp *reflect.ArrayType, state *DecState, p uintptr, elemOp dec } p = *(*uintptr)(up); } - if n := DecodeUint(state); n != uint64(length) { - return os.ErrorString("length mismatch in decodeArray"); + if n := decodeUint(state); n != uint64(length) { + return os.ErrorString("gob: length mismatch in decodeArray"); } return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir); } -func decodeSlice(atyp *reflect.SliceType, state *DecState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int) os.Error { - length := uintptr(DecodeUint(state)); +func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int) os.Error { + length := uintptr(decodeUint(state)); if indir > 0 { up := unsafe.Pointer(p); if *(*unsafe.Pointer)(up) == nil { @@ -412,13 +430,13 @@ func decOpFor(rt reflect.Type) (decOp, int) { break; } elemOp, elemIndir := decOpFor(t.Elem()); - op = func(i *decInstr, state *DecState, p unsafe.Pointer) { + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir); }; case *reflect.ArrayType: elemOp, elemIndir := decOpFor(t.Elem()); - op = func(i *decInstr, state *DecState, p unsafe.Pointer) { + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir); }; @@ -426,9 +444,9 @@ func decOpFor(rt reflect.Type) (decOp, int) { // Generate a closure that calls out to the engine for the nested type. engine := getDecEngine(typ); info := getTypeInfo(typ); - op = func(i *decInstr, state *DecState, p unsafe.Pointer) { + op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { // indirect through info to delay evaluation for recursive structs - state.err = decodeStruct(info.decoder, t, state.r, uintptr(p), i.indir) + state.err = decodeStruct(info.decoder, t, state.b, uintptr(p), i.indir) }; } } @@ -473,7 +491,7 @@ func getDecEngine(rt reflect.Type) *decEngine { return info.decoder; } -func Decode(r io.Reader, e interface{}) os.Error { +func decode(b *bytes.Buffer, e interface{}) os.Error { // Dereference down to the underlying object. rt, indir := indirect(reflect.Typeof(e)); v := reflect.NewValue(e); @@ -481,10 +499,10 @@ func Decode(r io.Reader, e interface{}) os.Error { v = reflect.Indirect(v); } if _, ok := v.(*reflect.StructValue); !ok { - return os.ErrorString("decode can't handle " + rt.String()) + return os.ErrorString("gob: decode can't handle " + rt.String()) } typeLock.Lock(); engine := getDecEngine(rt); typeLock.Unlock(); - return decodeStruct(engine, rt.(*reflect.StructType), r, uintptr(v.Addr()), 0); + return decodeStruct(engine, rt.(*reflect.StructType), b, uintptr(v.Addr()), 0); } diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index ef5481e109..e824ac754c 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -5,6 +5,7 @@ package gob import ( + "bytes"; "gob"; "io"; "os"; @@ -14,15 +15,19 @@ import ( type Decoder struct { sync.Mutex; // each item must be received atomically + r io.Reader; // source of the data seen map[TypeId] *wireType; // which types we've already seen described - state *DecState; // so we can encode integers, strings directly + state *decodeState; // reads data from in-memory buffer + countState *decodeState; // reads counts from wire + oneByte []byte; } func NewDecoder(r io.Reader) *Decoder { dec := new(Decoder); + dec.r = r; dec.seen = make(map[TypeId] *wireType); - dec.state = new(DecState); - dec.state.r = r; // the rest isn't important; all we need is buffer and reader + dec.state = new(decodeState); // buffer set in Decode(); rest is unimportant + dec.oneByte = make([]byte, 1); return dec; } @@ -36,7 +41,7 @@ func (dec *Decoder) recvType(id TypeId) { // Type: wire := new(wireType); - Decode(dec.state.r, wire); + decode(dec.state.b, wire); // Remember we've seen this type. dec.seen[id] = wire; } @@ -50,36 +55,56 @@ func (dec *Decoder) Decode(e interface{}) os.Error { dec.Lock(); defer dec.Unlock(); - var id TypeId; - for dec.state.err == nil { + dec.state.err = nil; + for { + // Read a count. + nbytes, err := decodeUintReader(dec.r, dec.oneByte); + if err != nil { + return err; + } + + // Read the data + buf := make([]byte, nbytes); // TODO(r): avoid repeated allocation + var n int; + n, err = dec.r.Read(buf); + if err != nil { + return err; + } + if n < int(nbytes) { + return os.ErrorString("gob decode: short read"); + } + + dec.state.b = bytes.NewBuffer(buf); // TODO(r): avoid repeated allocation // Receive a type id. - id = TypeId(DecodeInt(dec.state)); + id := TypeId(decodeInt(dec.state)); + if dec.state.err != nil { + return dec.state.err + } - // If the id is positive, we have a value. 0 is the error state - if id >= 0 { - break; + if id < 0 { // 0 is the error state, handled above + // If the id is negative, we have a type. + dec.recvType(-id); + if dec.state.err != nil { + return dec.state.err + } + continue; } - // The id is negative; a type descriptor follows. - dec.recvType(-id); - } - if dec.state.err != nil { - return dec.state.err - } + // we have a value + info := getTypeInfo(rt); + + // Check type compatibility. + // TODO(r): need to make the decoder work correctly if the wire type is compatible + // but not equal to the local type (e.g, extra fields). + if info.wire.name() != dec.seen[id].name() { + dec.state.err = os.ErrorString("gob decode: incorrect type for wire value: want " + info.wire.name() + "; received " + dec.seen[id].name()); + return dec.state.err + } - info := getTypeInfo(rt); + // Receive a value. + decode(dec.state.b, e); - // Check type compatibility. - // TODO(r): need to make the decoder work correctly if the wire type is compatible - // but not equal to the local type (e.g, extra fields). - if info.wire.name() != dec.seen[id].name() { - dec.state.err = os.ErrorString("gob decode: incorrect type for wire value: want " + info.wire.name() + "; received " + dec.seen[id].name()); return dec.state.err } - - // Receive a value. - Decode(dec.state.r, e); - - // Release and return. - return dec.state.err + return nil // silence compiler } diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index dac8097518..7f12658145 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -5,6 +5,7 @@ package gob import ( + "bytes"; "gob"; "io"; "math"; @@ -18,8 +19,8 @@ import ( // Field numbers are delta encoded and always increase. The field // number is initialized to -1 so 0 comes out as delta(1). A delta of // 0 terminates the structure. -type EncState struct { - w io.Writer; +type encoderState struct { + b *bytes.Buffer; err os.Error; // error encountered during encoding; fieldnum int; // the last field number written. buf [16]byte; // buffer used by the encoder; here to avoid allocation. @@ -30,37 +31,36 @@ type EncState struct { // That way there's only one bit to clear and the value is a little easier to see if // you're the unfortunate sort of person who must read the hex to debug. -// EncodeUint writes an encoded unsigned integer to state.w. Sets state.err. +// encodeUint writes an encoded unsigned integer to state.b. Sets state.err. // If state.err is already non-nil, it does nothing. -func EncodeUint(state *EncState, x uint64) { +func encodeUint(state *encoderState, x uint64) { var n int; if state.err != nil { return } - for n = 0; x > 127; n++ { + for n = 0; x > 0x7F; n++ { state.buf[n] = uint8(x & 0x7F); x >>= 7; } state.buf[n] = 0x80 | uint8(x); - var nn int; - nn, state.err = state.w.Write(state.buf[0:n+1]); + n, state.err = state.b.Write(state.buf[0:n+1]); } -// EncodeInt writes an encoded signed integer to state.w. +// encodeInt writes an encoded signed integer to state.w. // The low bit of the encoding says whether to bit complement the (other bits of the) uint to recover the int. // Sets state.err. If state.err is already non-nil, it does nothing. -func EncodeInt(state *EncState, i int64){ +func encodeInt(state *encoderState, i int64){ var x uint64; if i < 0 { x = uint64(^i << 1) | 1 } else { x = uint64(i << 1) } - EncodeUint(state, uint64(x)) + encodeUint(state, uint64(x)) } type encInstr struct -type encOp func(i *encInstr, state *EncState, p unsafe.Pointer) +type encOp func(i *encInstr, state *encoderState, p unsafe.Pointer) // The 'instructions' of the encoding machine type encInstr struct { @@ -72,9 +72,9 @@ type encInstr struct { // Emit a field number and update the state to record its value for delta encoding. // If the instruction pointer is nil, do nothing -func (state *EncState) update(instr *encInstr) { +func (state *encoderState) update(instr *encInstr) { if instr != nil { - EncodeUint(state, uint64(instr.field - state.fieldnum)); + encodeUint(state, uint64(instr.field - state.fieldnum)); state.fieldnum = instr.field; } } @@ -95,99 +95,99 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { return p } -func encBool(i *encInstr, state *EncState, p unsafe.Pointer) { +func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*bool)(p); if b { state.update(i); - EncodeUint(state, 1); + encodeUint(state, 1); } } -func encInt(i *encInstr, state *EncState, p unsafe.Pointer) { +func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int)(p)); if v != 0 { state.update(i); - EncodeInt(state, v); + encodeInt(state, v); } } -func encUint(i *encInstr, state *EncState, p unsafe.Pointer) { +func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint)(p)); if v != 0 { state.update(i); - EncodeUint(state, v); + encodeUint(state, v); } } -func encInt8(i *encInstr, state *EncState, p unsafe.Pointer) { +func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int8)(p)); if v != 0 { state.update(i); - EncodeInt(state, v); + encodeInt(state, v); } } -func encUint8(i *encInstr, state *EncState, p unsafe.Pointer) { +func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint8)(p)); if v != 0 { state.update(i); - EncodeUint(state, v); + encodeUint(state, v); } } -func encInt16(i *encInstr, state *EncState, p unsafe.Pointer) { +func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int16)(p)); if v != 0 { state.update(i); - EncodeInt(state, v); + encodeInt(state, v); } } -func encUint16(i *encInstr, state *EncState, p unsafe.Pointer) { +func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint16)(p)); if v != 0 { state.update(i); - EncodeUint(state, v); + encodeUint(state, v); } } -func encInt32(i *encInstr, state *EncState, p unsafe.Pointer) { +func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int32)(p)); if v != 0 { state.update(i); - EncodeInt(state, v); + encodeInt(state, v); } } -func encUint32(i *encInstr, state *EncState, p unsafe.Pointer) { +func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint32)(p)); if v != 0 { state.update(i); - EncodeUint(state, v); + encodeUint(state, v); } } -func encInt64(i *encInstr, state *EncState, p unsafe.Pointer) { +func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { v := *(*int64)(p); if v != 0 { state.update(i); - EncodeInt(state, v); + encodeInt(state, v); } } -func encUint64(i *encInstr, state *EncState, p unsafe.Pointer) { +func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { v := *(*uint64)(p); if v != 0 { state.update(i); - EncodeUint(state, v); + encodeUint(state, v); } } -func encUintptr(i *encInstr, state *EncState, p unsafe.Pointer) { +func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uintptr)(p)); if v != 0 { state.update(i); - EncodeUint(state, v); + encodeUint(state, v); } } @@ -207,56 +207,56 @@ func floatBits(f float64) uint64 { return v; } -func encFloat(i *encInstr, state *EncState, p unsafe.Pointer) { +func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) { f := float(*(*float)(p)); if f != 0 { v := floatBits(float64(f)); state.update(i); - EncodeUint(state, v); + encodeUint(state, v); } } -func encFloat32(i *encInstr, state *EncState, p unsafe.Pointer) { +func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { f := float32(*(*float32)(p)); if f != 0 { v := floatBits(float64(f)); state.update(i); - EncodeUint(state, v); + encodeUint(state, v); } } -func encFloat64(i *encInstr, state *EncState, p unsafe.Pointer) { +func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { f := *(*float64)(p); if f != 0 { state.update(i); v := floatBits(f); - EncodeUint(state, v); + encodeUint(state, v); } } // Byte arrays are encoded as an unsigned count followed by the raw bytes. -func encUint8Array(i *encInstr, state *EncState, p unsafe.Pointer) { +func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*[]byte)(p); if len(b) > 0 { state.update(i); - EncodeUint(state, uint64(len(b))); - state.w.Write(b); + encodeUint(state, uint64(len(b))); + state.b.Write(b); } } // Strings are encoded as an unsigned count followed by the raw bytes. -func encString(i *encInstr, state *EncState, p unsafe.Pointer) { +func encString(i *encInstr, state *encoderState, p unsafe.Pointer) { s := *(*string)(p); if len(s) > 0 { state.update(i); - EncodeUint(state, uint64(len(s))); - io.WriteString(state.w, s); + encodeUint(state, uint64(len(s))); + io.WriteString(state.b, s); } } // The end of a struct is marked by a delta field number of 0. -func encStructTerminator(i *encInstr, state *EncState, p unsafe.Pointer) { - EncodeUint(state, 0); +func encStructTerminator(i *encInstr, state *encoderState, p unsafe.Pointer) { + encodeUint(state, 0); } // Execution engine @@ -267,9 +267,9 @@ type encEngine struct { instr []encInstr } -func encodeStruct(engine *encEngine, w io.Writer, basep uintptr) os.Error { - state := new(EncState); - state.w = w; +func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { + state := new(encoderState); + state.b = b; state.fieldnum = -1; for i := 0; i < len(engine.instr); i++ { instr := &engine.instr[i]; @@ -287,17 +287,17 @@ func encodeStruct(engine *encEngine, w io.Writer, basep uintptr) os.Error { return state.err } -func encodeArray(w io.Writer, p uintptr, op encOp, elemWid uintptr, length int, elemIndir int) os.Error { - state := new(EncState); - state.w = w; +func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, length int, elemIndir int) os.Error { + state := new(encoderState); + state.b = b; state.fieldnum = -1; - EncodeUint(state, uint64(length)); + encodeUint(state, uint64(length)); for i := 0; i < length && state.err == nil; i++ { elemp := p; up := unsafe.Pointer(elemp); if elemIndir > 0 { if up = encIndirect(up, elemIndir); up == nil { - state.err = os.ErrorString("encodeArray: nil element"); + state.err = os.ErrorString("gob: encodeArray: nil element"); break } elemp = uintptr(up); @@ -345,29 +345,29 @@ func encOpFor(rt reflect.Type) (encOp, int) { } // Slices have a header; we decode it to find the underlying array. elemOp, indir := encOpFor(t.Elem()); - op = func(i *encInstr, state *EncState, p unsafe.Pointer) { + op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { slice := (*reflect.SliceHeader)(p); if slice.Len == 0 { return } state.update(i); - state.err = encodeArray(state.w, slice.Data, elemOp, t.Elem().Size(), int(slice.Len), indir); + state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), int(slice.Len), indir); }; case *reflect.ArrayType: // True arrays have size in the type. elemOp, indir := encOpFor(t.Elem()); - op = func(i *encInstr, state *EncState, p unsafe.Pointer) { + op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i); - state.err = encodeArray(state.w, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir); + state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir); }; case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. engine := getEncEngine(typ); info := getTypeInfo(typ); - op = func(i *encInstr, state *EncState, p unsafe.Pointer) { + op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i); // indirect through info to delay evaluation for recursive structs - state.err = encodeStruct(info.encoder, state.w, uintptr(p)); + state.err = encodeStruct(info.encoder, state.b, uintptr(p)); }; } } @@ -406,7 +406,7 @@ func getEncEngine(rt reflect.Type) *encEngine { return info.encoder; } -func Encode(w io.Writer, e interface{}) os.Error { +func encode(b *bytes.Buffer, e interface{}) os.Error { // Dereference down to the underlying object. rt, indir := indirect(reflect.Typeof(e)); v := reflect.NewValue(e); @@ -414,10 +414,10 @@ func Encode(w io.Writer, e interface{}) os.Error { v = reflect.Indirect(v); } if _, ok := v.(*reflect.StructValue); !ok { - return os.ErrorString("encode can't handle " + v.Type().String()) + return os.ErrorString("gob: encode can't handle " + v.Type().String()) } typeLock.Lock(); engine := getEncEngine(rt); typeLock.Unlock(); - return encodeStruct(engine, w, v.Addr()); + return encodeStruct(engine, b, v.Addr()); } diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 30ec819c77..b3a420a86a 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -5,6 +5,7 @@ package gob import ( + "bytes"; "gob"; "io"; "os"; @@ -14,20 +15,45 @@ import ( type Encoder struct { sync.Mutex; // each item must be sent atomically + w io.Writer; // where to send the data sent map[reflect.Type] TypeId; // which types we've already sent - state *EncState; // so we can encode integers, strings directly + state *encoderState; // so we can encode integers, strings directly + countState *encoderState; // stage for writing counts + buf []byte; // for collecting the output. } func NewEncoder(w io.Writer) *Encoder { enc := new(Encoder); + enc.w = w; enc.sent = make(map[reflect.Type] TypeId); - enc.state = new(EncState); - enc.state.w = w; // the rest isn't important; all we need is buffer and writer + enc.state = new(encoderState); + enc.state.b = new(bytes.Buffer); // the rest isn't important; all we need is buffer and writer + enc.countState = new(encoderState); + enc.countState.b = new(bytes.Buffer); // the rest isn't important; all we need is buffer and writer return enc; } func (enc *Encoder) badType(rt reflect.Type) { - enc.state.err = os.ErrorString("can't encode type " + rt.String()); + enc.state.err = os.ErrorString("gob: can't encode type " + rt.String()); +} + +// Send the data item preceded by a unsigned count of its length. +func (enc *Encoder) send() { + // Encode the length. + encodeUint(enc.countState, uint64(enc.state.b.Len())); + // Build the buffer. + countLen := enc.countState.b.Len(); + total := countLen + enc.state.b.Len(); + if total > len(enc.buf) { + enc.buf = make([]byte, total+1000); // extra for growth + } + // Place the length before the data. + // TODO(r): avoid the extra copy here. + enc.countState.b.Read(enc.buf[0:countLen]); + // Now the data. + enc.state.b.Read(enc.buf[countLen:total]); + // Write the data. + enc.w.Write(enc.buf[0:total]); } func (enc *Encoder) sendType(origt reflect.Type) { @@ -63,9 +89,11 @@ func (enc *Encoder) sendType(origt reflect.Type) { info := getTypeInfo(rt); // Send the pair (-id, type) // Id: - EncodeInt(enc.state, -int64(info.typeId)); + encodeInt(enc.state, -int64(info.typeId)); // Type: - Encode(enc.state.w, info.wire); + encode(enc.state.b, info.wire); + enc.send(); + // Remember we've sent this type. enc.sent[rt] = info.typeId; // Remember we've sent the top-level, possibly indirect type too. @@ -78,6 +106,9 @@ func (enc *Encoder) sendType(origt reflect.Type) { } func (enc *Encoder) Encode(e interface{}) os.Error { + if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 { + panicln("Encoder: buffer not empty") + } rt, indir := indirect(reflect.Typeof(e)); // Make sure we're single-threaded through here. @@ -90,16 +121,18 @@ func (enc *Encoder) Encode(e interface{}) os.Error { // No, so send it. enc.sendType(rt); if enc.state.err != nil { + enc.state.b.Reset(); + enc.countState.b.Reset(); return enc.state.err } } // Identify the type of this top-level value. - EncodeInt(enc.state, int64(enc.sent[rt])); + encodeInt(enc.state, int64(enc.sent[rt])); - // Finally, send the data - Encode(enc.state.w, e); + // Encode the object. + encode(enc.state.b, e); + enc.send(); - // Release and return. return enc.state.err } diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index 1640ac72a5..b4e9f5b553 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -57,17 +57,22 @@ func TestBasicEncoder(t *testing.T) { } // Decode the result by hand to verify; - state := new(DecState); - state.r = b; + state := new(decodeState); + state.b = b; // The output should be: + // 0) The length, 38. + length := decodeUint(state); + if length != 38 { + t.Fatal("0. expected length 38; got", length); + } // 1) -7: the type id of ET1 - id1 := DecodeInt(state); + id1 := decodeInt(state); if id1 >= 0 { t.Fatal("expected ET1 negative id; got", id1); } // 2) The wireType for ET1 wire1 := new(wireType); - err := Decode(b, wire1); + err := decode(b, wire1); if err != nil { t.Fatal("error decoding ET1 type:", err); } @@ -76,14 +81,19 @@ func TestBasicEncoder(t *testing.T) { if !reflect.DeepEqual(wire1, trueWire1) { t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1); } - // 3) -8: the type id of ET2 - id2 := DecodeInt(state); + // 3) The length, 21. + length = decodeUint(state); + if length != 21 { + t.Fatal("3. expected length 21; got", length); + } + // 4) -8: the type id of ET2 + id2 := decodeInt(state); if id2 >= 0 { t.Fatal("expected ET2 negative id; got", id2); } - // 4) The wireType for ET2 + // 5) The wireType for ET2 wire2 := new(wireType); - err = Decode(b, wire2); + err = decode(b, wire2); if err != nil { t.Fatal("error decoding ET2 type:", err); } @@ -92,21 +102,26 @@ func TestBasicEncoder(t *testing.T) { if !reflect.DeepEqual(wire2, trueWire2) { t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2); } - // 5) The type id for the et1 value - newId1 := DecodeInt(state); + // 6) The length, 6. + length = decodeUint(state); + if length != 6 { + t.Fatal("6. expected length 6; got", length); + } + // 7) The type id for the et1 value + newId1 := decodeInt(state); if newId1 != -id1 { t.Fatal("expected Et1 id", -id1, "got", newId1); } - // 6) The value of et1 + // 8) The value of et1 newEt1 := new(ET1); - err = Decode(b, newEt1); + err = decode(b, newEt1); if err != nil { t.Fatal("error decoding ET1 value:", err); } if !reflect.DeepEqual(et1, newEt1) { t.Fatalf("invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1); } - // 7) EOF + // 9) EOF if b.Len() != 0 { t.Error("not at eof;", b.Len(), "bytes left") } @@ -117,14 +132,19 @@ func TestBasicEncoder(t *testing.T) { if enc.state.err != nil { t.Error("2nd round: encoder fail:", enc.state.err) } + // The length. + length = decodeUint(state); + if length != 6 { + t.Fatal("6. expected length 6; got", length); + } // 5a) The type id for the et1 value - newId1 = DecodeInt(state); + newId1 = decodeInt(state); if newId1 != -id1 { t.Fatal("2nd round: expected Et1 id", -id1, "got", newId1); } // 6a) The value of et1 newEt1 = new(ET1); - err = Decode(b, newEt1); + err = decode(b, newEt1); if err != nil { t.Fatal("2nd round: error decoding ET1 value:", err); } diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index 7eaae05a1b..00ff82494d 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -17,7 +17,7 @@ import ( // Internally, they are used as keys to a map to recover the underlying type info. type TypeId int32 -var id TypeId // incremented for each new type we build +var nextId TypeId // incremented for each new type we build var typeLock sync.Mutex // set while building a type type gobType interface { @@ -31,9 +31,9 @@ var types = make(map[reflect.Type] gobType) var idToType = make(map[TypeId] gobType) func setTypeId(typ gobType) { - id++; - typ.setId(id); - idToType[id] = typ; + nextId++; + typ.setId(nextId); + idToType[nextId] = typ; } func (t TypeId) gobType() gobType { @@ -296,7 +296,7 @@ func bootstrapType(name string, e interface{}) TypeId { typ := &commonType{ name: name }; types[rt] = typ; setTypeId(typ); - return id + return nextId } // Representation of the information we send and receive about this type. -- 2.48.1