From ce3ba126a0c5370c2c4c4e1ef32291316b96b5b0 Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Wed, 14 Aug 2013 00:18:48 -0400 Subject: [PATCH] encoding/gob: support new generic interfaces in package encoding R=r CC=golang-dev https://golang.org/cl/12681044 --- src/pkg/encoding/gob/debug.go | 10 ++ src/pkg/encoding/gob/decode.go | 40 ++++--- src/pkg/encoding/gob/doc.go | 35 +++--- src/pkg/encoding/gob/encode.go | 29 +++-- src/pkg/encoding/gob/encoder.go | 4 +- src/pkg/encoding/gob/gobencdec_test.go | 144 +++++++++++++++++++++---- src/pkg/encoding/gob/type.go | 83 ++++++++++---- 7 files changed, 269 insertions(+), 76 deletions(-) diff --git a/src/pkg/encoding/gob/debug.go b/src/pkg/encoding/gob/debug.go index 31d1351fc4..6117eb0837 100644 --- a/src/pkg/encoding/gob/debug.go +++ b/src/pkg/encoding/gob/debug.go @@ -415,6 +415,16 @@ func (deb *debugger) typeDefinition(indent tab, id typeId) { deb.delta(1) com := deb.common() wire.GobEncoderT = &gobEncoderType{com} + case 5: // BinaryMarshaler type, one field of {{Common}} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + wire.BinaryMarshalerT = &gobEncoderType{com} + case 6: // TextMarshaler type, one field of {{Common}} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + wire.TextMarshalerT = &gobEncoderType{com} default: errorf("bad field in type %d", fieldNum) } diff --git a/src/pkg/encoding/gob/decode.go b/src/pkg/encoding/gob/decode.go index 08829a4a0a..3e76f4c906 100644 --- a/src/pkg/encoding/gob/decode.go +++ b/src/pkg/encoding/gob/decode.go @@ -9,6 +9,7 @@ package gob import ( "bytes" + "encoding" "errors" "io" "math" @@ -767,15 +768,22 @@ func (dec *Decoder) ignoreInterface(state *decoderState) { // decodeGobDecoder decodes something implementing the GobDecoder interface. // The data is encoded as a byte slice. -func (dec *Decoder) decodeGobDecoder(state *decoderState, v reflect.Value) { +func (dec *Decoder) decodeGobDecoder(ut *userTypeInfo, state *decoderState, v reflect.Value) { // Read the bytes for the value. b := make([]byte, state.decodeUint()) _, err := state.b.Read(b) if err != nil { error_(err) } - // We know it's a GobDecoder, so just call the method directly. - err = v.Interface().(GobDecoder).GobDecode(b) + // We know it's one of these. + switch ut.externalDec { + case xGob: + err = v.Interface().(GobDecoder).GobDecode(b) + case xBinary: + err = v.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(b) + case xText: + err = v.Interface().(encoding.TextUnmarshaler).UnmarshalText(b) + } if err != nil { error_(err) } @@ -825,9 +833,10 @@ var decIgnoreOpMap = map[typeId]decOp{ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) { ut := userType(rt) // If the type implements GobEncoder, we handle it without further processing. - if ut.isGobDecoder { + if ut.externalDec != 0 { return dec.gobDecodeOpFor(ut) } + // If this type is already in progress, it's a recursive type (e.g. map[string]*T). // Return the pointer to the op we're already building. if opPtr := inProgress[rt]; opPtr != nil { @@ -954,7 +963,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { state.dec.ignoreStruct(*enginePtr) } - case wire.GobEncoderT != nil: + case wire.GobEncoderT != nil, wire.BinaryMarshalerT != nil, wire.TextMarshalerT != nil: op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.ignoreGobDecoder(state) } @@ -993,7 +1002,7 @@ func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { } else { v = reflect.NewAt(rcvrType, p).Elem() } - state.dec.decodeGobDecoder(state, v) + state.dec.decodeGobDecoder(ut, state, v) } return &op, int(ut.indir) @@ -1010,12 +1019,18 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re inProgress[fr] = fw ut := userType(fr) wire, ok := dec.wireType[fw] - // If fr is a GobDecoder, the wire type must be GobEncoder. - // And if fr is not a GobDecoder, the wire type must not be either. - if ut.isGobDecoder != (ok && wire.GobEncoderT != nil) { // the parentheses look odd but are correct. + // If wire was encoded with an encoding method, fr must have that method. + // And if not, it must not. + // At most one of the booleans in ut is set. + // We could possibly relax this constraint in the future in order to + // choose the decoding method using the data in the wireType. + // The parentheses look odd but are correct. + if (ut.externalDec == xGob) != (ok && wire.GobEncoderT != nil) || + (ut.externalDec == xBinary) != (ok && wire.BinaryMarshalerT != nil) || + (ut.externalDec == xText) != (ok && wire.TextMarshalerT != nil) { return false } - if ut.isGobDecoder { // This test trumps all others. + if ut.externalDec != 0 { // This test trumps all others. return true } switch t := ut.base; t.Kind() { @@ -1114,8 +1129,7 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) { rt := ut.base srt := rt - if srt.Kind() != reflect.Struct || - ut.isGobDecoder { + if srt.Kind() != reflect.Struct || ut.externalDec != 0 { return dec.compileSingle(remoteId, ut) } var wireStruct *structType @@ -1223,7 +1237,7 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) { return } engine := *enginePtr - if st := base; st.Kind() == reflect.Struct && !ut.isGobDecoder { + if st := base; st.Kind() == reflect.Struct && ut.externalDec == 0 { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { name := base.Name() errorf("type mismatch: no fields matched compiling decoder for %s", name) diff --git a/src/pkg/encoding/gob/doc.go b/src/pkg/encoding/gob/doc.go index 5bd61b12eb..dc0e325f97 100644 --- a/src/pkg/encoding/gob/doc.go +++ b/src/pkg/encoding/gob/doc.go @@ -67,19 +67,28 @@ point values may be received into any floating point variable. However, the destination variable must be able to represent the value or the decode operation will fail. -Structs, arrays and slices are also supported. Structs encode and -decode only exported fields. Strings and arrays of bytes are supported -with a special, efficient representation (see below). When a slice -is decoded, if the existing slice has capacity the slice will be -extended in place; if not, a new array is allocated. Regardless, -the length of the resulting slice reports the number of elements -decoded. - -Functions and channels cannot be sent in a gob. Attempting -to encode a value that contains one will fail. - -The rest of this comment documents the encoding, details that are not important -for most users. Details are presented bottom-up. +Structs, arrays and slices are also supported. Structs encode and decode only +exported fields. Strings and arrays of bytes are supported with a special, +efficient representation (see below). When a slice is decoded, if the existing +slice has capacity the slice will be extended in place; if not, a new array is +allocated. Regardless, the length of the resulting slice reports the number of +elements decoded. + +Functions and channels cannot be sent in a gob. Attempting to encode a value +that contains one will fail. + +Gob can encode a value of any type implementing the GobEncoder, +encoding.BinaryMarshaler, or encoding.TextMarshaler interfaces by calling the +corresponding method, in that order of preference. + +Gob can decode a value of any type implementing the GobDecoder, +encoding.BinaryUnmarshaler, or encoding.TextUnmarshaler interfaces by calling +the corresponding method, again in that order of preference. + +Encoding Details + +This section documents the encoding, details that are not important for most +users. Details are presented bottom-up. An unsigned integer is sent one of two ways. If it is less than 128, it is sent as a byte with that value. Otherwise it is sent as a minimal-length big-endian diff --git a/src/pkg/encoding/gob/encode.go b/src/pkg/encoding/gob/encode.go index ee9b0783e0..480faa305d 100644 --- a/src/pkg/encoding/gob/encode.go +++ b/src/pkg/encoding/gob/encode.go @@ -6,6 +6,7 @@ package gob import ( "bytes" + "encoding" "math" "reflect" "unsafe" @@ -511,10 +512,20 @@ func isZero(val reflect.Value) bool { // encGobEncoder encodes a value that implements the GobEncoder interface. // The data is sent as a byte array. -func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, v reflect.Value) { +func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, ut *userTypeInfo, v reflect.Value) { // TODO: should we catch panics from the called method? - // We know it's a GobEncoder, so just call the method directly. - data, err := v.Interface().(GobEncoder).GobEncode() + + var data []byte + var err error + // We know it's one of these. + switch ut.externalEnc { + case xGob: + data, err = v.Interface().(GobEncoder).GobEncode() + case xBinary: + data, err = v.Interface().(encoding.BinaryMarshaler).MarshalBinary() + case xText: + data, err = v.Interface().(encoding.TextMarshaler).MarshalText() + } if err != nil { error_(err) } @@ -550,7 +561,7 @@ var encOpTable = [...]encOp{ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) { ut := userType(rt) // If the type implements GobEncoder, we handle it without further processing. - if ut.isGobEncoder { + if ut.externalEnc != 0 { return enc.gobEncodeOpFor(ut) } // If this type is already in progress, it's a recursive type (e.g. map[string]*T). @@ -661,7 +672,7 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { return } state.update(i) - state.enc.encodeGobEncoder(state.b, v) + state.enc.encodeGobEncoder(state.b, ut, v) } return &op, int(ut.encIndir) // encIndir: op will get called with p == address of receiver. } @@ -672,10 +683,10 @@ func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine { engine := new(encEngine) seen := make(map[reflect.Type]*encOp) rt := ut.base - if ut.isGobEncoder { + if ut.externalEnc != 0 { rt = ut.user } - if !ut.isGobEncoder && + if ut.externalEnc == 0 && srt.Kind() == reflect.Struct { for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ { f := srt.Field(fieldNum) @@ -736,13 +747,13 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf defer catchError(&enc.err) engine := enc.lockAndGetEncEngine(ut) indir := ut.indir - if ut.isGobEncoder { + if ut.externalEnc != 0 { indir = int(ut.encIndir) } for i := 0; i < indir; i++ { value = reflect.Indirect(value) } - if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct { + if ut.externalEnc == 0 && value.Type().Kind() == reflect.Struct { enc.encodeStruct(b, engine, unsafeAddr(value)) } else { enc.encodeSingle(b, engine, unsafeAddr(value)) diff --git a/src/pkg/encoding/gob/encoder.go b/src/pkg/encoding/gob/encoder.go index f669c3d5b2..332a607c2b 100644 --- a/src/pkg/encoding/gob/encoder.go +++ b/src/pkg/encoding/gob/encoder.go @@ -135,7 +135,7 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp // sendType sends the type info to the other side, if necessary. func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { ut := userType(origt) - if ut.isGobEncoder { + if ut.externalEnc != 0 { // The rules are different: regardless of the underlying type's representation, // we need to tell the other side that the base type is a GobEncoder. return enc.sendActualType(w, state, ut, ut.base) @@ -184,7 +184,7 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use // Make sure the type is known to the other side. // First, have we already sent this type? rt := ut.base - if ut.isGobEncoder { + if ut.externalEnc != 0 { rt = ut.user } if _, alreadySent := enc.sent[rt]; !alreadySent { diff --git a/src/pkg/encoding/gob/gobencdec_test.go b/src/pkg/encoding/gob/gobencdec_test.go index ddcd80b1a7..4e49aeda21 100644 --- a/src/pkg/encoding/gob/gobencdec_test.go +++ b/src/pkg/encoding/gob/gobencdec_test.go @@ -34,6 +34,14 @@ type Gobber int type ValueGobber string // encodes with a value, decodes with a pointer. +type BinaryGobber int + +type BinaryValueGobber string + +type TextGobber int + +type TextValueGobber string + // The relevant methods func (g *ByteStruct) GobEncode() ([]byte, error) { @@ -101,6 +109,24 @@ func (g *Gobber) GobDecode(data []byte) error { return err } +func (g *BinaryGobber) MarshalBinary() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%d", *g)), nil +} + +func (g *BinaryGobber) UnmarshalBinary(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g)) + return err +} + +func (g *TextGobber) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%d", *g)), nil +} + +func (g *TextGobber) UnmarshalText(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g)) + return err +} + func (v ValueGobber) GobEncode() ([]byte, error) { return []byte(fmt.Sprintf("VALUE=%s", v)), nil } @@ -110,6 +136,24 @@ func (v *ValueGobber) GobDecode(data []byte) error { return err } +func (v BinaryValueGobber) MarshalBinary() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%s", v)), nil +} + +func (v *BinaryValueGobber) UnmarshalBinary(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v)) + return err +} + +func (v TextValueGobber) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("VALUE=%s", v)), nil +} + +func (v *TextValueGobber) UnmarshalText(data []byte) error { + _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v)) + return err +} + // Structs that include GobEncodable fields. type GobTest0 struct { @@ -130,28 +174,42 @@ type GobTest2 struct { type GobTest3 struct { X int // guarantee we have something in common with GobTest* G *Gobber + B *BinaryGobber + T *TextGobber } type GobTest4 struct { - X int // guarantee we have something in common with GobTest* - V ValueGobber + X int // guarantee we have something in common with GobTest* + V ValueGobber + BV BinaryValueGobber + TV TextValueGobber } type GobTest5 struct { - X int // guarantee we have something in common with GobTest* - V *ValueGobber + X int // guarantee we have something in common with GobTest* + V *ValueGobber + BV *BinaryValueGobber + TV *TextValueGobber } type GobTest6 struct { - X int // guarantee we have something in common with GobTest* - V ValueGobber - W *ValueGobber + X int // guarantee we have something in common with GobTest* + V ValueGobber + W *ValueGobber + BV BinaryValueGobber + BW *BinaryValueGobber + TV TextValueGobber + TW *TextValueGobber } type GobTest7 struct { - X int // guarantee we have something in common with GobTest* - V *ValueGobber - W ValueGobber + X int // guarantee we have something in common with GobTest* + V *ValueGobber + W ValueGobber + BV *BinaryValueGobber + BW BinaryValueGobber + TV *TextValueGobber + TW TextValueGobber } type GobTestIgnoreEncoder struct { @@ -198,7 +256,9 @@ func TestGobEncoderField(t *testing.T) { // Now a field that's not a structure. b.Reset() gobber := Gobber(23) - err = enc.Encode(GobTest3{17, &gobber}) + bgobber := BinaryGobber(24) + tgobber := TextGobber(25) + err = enc.Encode(GobTest3{17, &gobber, &bgobber, &tgobber}) if err != nil { t.Fatal("encode error:", err) } @@ -207,7 +267,7 @@ func TestGobEncoderField(t *testing.T) { if err != nil { t.Fatal("decode error:", err) } - if *y.G != 23 { + if *y.G != 23 || *y.B != 24 || *y.T != 25 { t.Errorf("expected '23 got %d", *y.G) } } @@ -357,7 +417,7 @@ func TestGobEncoderValueEncoder(t *testing.T) { // first, string in field to byte in field b := new(bytes.Buffer) enc := NewEncoder(b) - err := enc.Encode(GobTest4{17, ValueGobber("hello")}) + err := enc.Encode(GobTest4{17, ValueGobber("hello"), BinaryValueGobber("Καλημέρα"), TextValueGobber("こんにちは")}) if err != nil { t.Fatal("encode error:", err) } @@ -367,7 +427,7 @@ func TestGobEncoderValueEncoder(t *testing.T) { if err != nil { t.Fatal("decode error:", err) } - if *x.V != "hello" { + if *x.V != "hello" || *x.BV != "Καλημέρα" || *x.TV != "こんにちは" { t.Errorf("expected `hello` got %s", x.V) } } @@ -377,13 +437,17 @@ func TestGobEncoderValueEncoder(t *testing.T) { func TestGobEncoderValueThenPointer(t *testing.T) { v := ValueGobber("forty-two") w := ValueGobber("six-by-nine") + bv := BinaryValueGobber("1nanocentury") + bw := BinaryValueGobber("πseconds") + tv := TextValueGobber("gravitationalacceleration") + tw := TextValueGobber("π²ft/s²") // this was a bug: encoding a GobEncoder by value before a GobEncoder // pointer would cause duplicate type definitions to be sent. b := new(bytes.Buffer) enc := NewEncoder(b) - if err := enc.Encode(GobTest6{42, v, &w}); err != nil { + if err := enc.Encode(GobTest6{42, v, &w, bv, &bw, tv, &tw}); err != nil { t.Fatal("encode error:", err) } dec := NewDecoder(b) @@ -391,6 +455,7 @@ func TestGobEncoderValueThenPointer(t *testing.T) { if err := dec.Decode(x); err != nil { t.Fatal("decode error:", err) } + if got, want := x.V, v; got != want { t.Errorf("v = %q, want %q", got, want) } @@ -399,6 +464,24 @@ func TestGobEncoderValueThenPointer(t *testing.T) { } else if *got != want { t.Errorf("w = %q, want %q", *got, want) } + + if got, want := x.BV, bv; got != want { + t.Errorf("bv = %q, want %q", got, want) + } + if got, want := x.BW, bw; got == nil { + t.Errorf("bw = nil, want %q", want) + } else if *got != want { + t.Errorf("bw = %q, want %q", *got, want) + } + + if got, want := x.TV, tv; got != want { + t.Errorf("tv = %q, want %q", got, want) + } + if got, want := x.TW, tw; got == nil { + t.Errorf("tw = nil, want %q", want) + } else if *got != want { + t.Errorf("tw = %q, want %q", *got, want) + } } // Test that we can use a pointer then a value type of a GobEncoder @@ -406,10 +489,14 @@ func TestGobEncoderValueThenPointer(t *testing.T) { func TestGobEncoderPointerThenValue(t *testing.T) { v := ValueGobber("forty-two") w := ValueGobber("six-by-nine") + bv := BinaryValueGobber("1nanocentury") + bw := BinaryValueGobber("πseconds") + tv := TextValueGobber("gravitationalacceleration") + tw := TextValueGobber("π²ft/s²") b := new(bytes.Buffer) enc := NewEncoder(b) - if err := enc.Encode(GobTest7{42, &v, w}); err != nil { + if err := enc.Encode(GobTest7{42, &v, w, &bv, bw, &tv, tw}); err != nil { t.Fatal("encode error:", err) } dec := NewDecoder(b) @@ -417,14 +504,33 @@ func TestGobEncoderPointerThenValue(t *testing.T) { if err := dec.Decode(x); err != nil { t.Fatal("decode error:", err) } + if got, want := x.V, v; got == nil { t.Errorf("v = nil, want %q", want) } else if *got != want { - t.Errorf("v = %q, want %q", got, want) + t.Errorf("v = %q, want %q", *got, want) } if got, want := x.W, w; got != want { t.Errorf("w = %q, want %q", got, want) } + + if got, want := x.BV, bv; got == nil { + t.Errorf("bv = nil, want %q", want) + } else if *got != want { + t.Errorf("bv = %q, want %q", *got, want) + } + if got, want := x.BW, bw; got != want { + t.Errorf("bw = %q, want %q", got, want) + } + + if got, want := x.TV, tv; got == nil { + t.Errorf("tv = nil, want %q", want) + } else if *got != want { + t.Errorf("tv = %q, want %q", *got, want) + } + if got, want := x.TW, tw; got != want { + t.Errorf("tw = %q, want %q", got, want) + } } func TestGobEncoderFieldTypeError(t *testing.T) { @@ -521,7 +627,9 @@ func TestGobEncoderIgnoreNonStructField(t *testing.T) { // First a field that's a structure. enc := NewEncoder(b) gobber := Gobber(23) - err := enc.Encode(GobTest3{17, &gobber}) + bgobber := BinaryGobber(24) + tgobber := TextGobber(25) + err := enc.Encode(GobTest3{17, &gobber, &bgobber, &tgobber}) if err != nil { t.Fatal("encode error:", err) } diff --git a/src/pkg/encoding/gob/type.go b/src/pkg/encoding/gob/type.go index 7fa0b499f0..0e49b30d70 100644 --- a/src/pkg/encoding/gob/type.go +++ b/src/pkg/encoding/gob/type.go @@ -5,6 +5,7 @@ package gob import ( + "encoding" "errors" "fmt" "os" @@ -18,14 +19,21 @@ import ( // to the package. It's computed once and stored in a map keyed by reflection // type. type userTypeInfo struct { - user reflect.Type // the type the user handed us - base reflect.Type // the base type after all indirections - indir int // number of indirections to reach the base type - isGobEncoder bool // does the type implement GobEncoder? - isGobDecoder bool // does the type implement GobDecoder? - encIndir int8 // number of indirections to reach the receiver type; may be negative - decIndir int8 // number of indirections to reach the receiver type; may be negative -} + user reflect.Type // the type the user handed us + base reflect.Type // the base type after all indirections + indir int // number of indirections to reach the base type + externalEnc int // xGob, xBinary, or xText + externalDec int // xGob, xBinary or xText + encIndir int8 // number of indirections to reach the receiver type; may be negative + decIndir int8 // number of indirections to reach the receiver type; may be negative +} + +// externalEncoding bits +const ( + xGob = 1 + iota // GobEncoder or GobDecoder + xBinary // encoding.BinaryMarshaler or encoding.BinaryUnmarshaler + xText // encoding.TextMarshaler or encoding.TextUnmarshaler +) var ( // Protected by an RWMutex because we read it a lot and write @@ -75,15 +83,34 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err error) { } ut.indir++ } - ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderInterfaceType) - ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderInterfaceType) + + if ok, indir := implementsInterface(ut.user, gobEncoderInterfaceType); ok { + ut.externalEnc, ut.encIndir = xGob, indir + } else if ok, indir := implementsInterface(ut.user, binaryMarshalerInterfaceType); ok { + ut.externalEnc, ut.encIndir = xBinary, indir + } else if ok, indir := implementsInterface(ut.user, textMarshalerInterfaceType); ok { + ut.externalEnc, ut.encIndir = xText, indir + } + + if ok, indir := implementsInterface(ut.user, gobDecoderInterfaceType); ok { + ut.externalDec, ut.decIndir = xGob, indir + } else if ok, indir := implementsInterface(ut.user, binaryUnmarshalerInterfaceType); ok { + ut.externalDec, ut.decIndir = xBinary, indir + } else if ok, indir := implementsInterface(ut.user, textUnmarshalerInterfaceType); ok { + ut.externalDec, ut.decIndir = xText, indir + } + userTypeCache[rt] = ut return } var ( - gobEncoderInterfaceType = reflect.TypeOf((*GobEncoder)(nil)).Elem() - gobDecoderInterfaceType = reflect.TypeOf((*GobDecoder)(nil)).Elem() + gobEncoderInterfaceType = reflect.TypeOf((*GobEncoder)(nil)).Elem() + gobDecoderInterfaceType = reflect.TypeOf((*GobDecoder)(nil)).Elem() + binaryMarshalerInterfaceType = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem() + binaryUnmarshalerInterfaceType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem() + textMarshalerInterfaceType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + textUnmarshalerInterfaceType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() ) // implementsInterface reports whether the type implements the @@ -412,7 +439,7 @@ func newStructType(name string) *structType { // works through typeIds and userTypeInfos alone. func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) { // Does this type implement GobEncoder? - if ut.isGobEncoder { + if ut.externalEnc != 0 { return newGobEncoderType(name), nil } var err error @@ -593,11 +620,13 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { // To maintain binary compatibility, if you extend this type, always put // the new fields last. type wireType struct { - ArrayT *arrayType - SliceT *sliceType - StructT *structType - MapT *mapType - GobEncoderT *gobEncoderType + ArrayT *arrayType + SliceT *sliceType + StructT *structType + MapT *mapType + GobEncoderT *gobEncoderType + BinaryMarshalerT *gobEncoderType + TextMarshalerT *gobEncoderType } func (w *wireType) string() string { @@ -616,6 +645,10 @@ func (w *wireType) string() string { return w.MapT.Name case w.GobEncoderT != nil: return w.GobEncoderT.Name + case w.BinaryMarshalerT != nil: + return w.BinaryMarshalerT.Name + case w.TextMarshalerT != nil: + return w.TextMarshalerT.Name } return unknown } @@ -631,7 +664,7 @@ var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock // typeLock must be held. func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) { rt := ut.base - if ut.isGobEncoder { + if ut.externalEnc != 0 { // We want the user type, not the base type. rt = ut.user } @@ -646,12 +679,20 @@ func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) { } info.id = gt.id() - if ut.isGobEncoder { + if ut.externalEnc != 0 { userType, err := getType(rt.Name(), ut, rt) if err != nil { return nil, err } - info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)} + gt := userType.id().gobType().(*gobEncoderType) + switch ut.externalEnc { + case xGob: + info.wire = &wireType{GobEncoderT: gt} + case xBinary: + info.wire = &wireType{BinaryMarshalerT: gt} + case xText: + info.wire = &wireType{TextMarshalerT: gt} + } typeInfoMap[ut.user] = info return info, nil } -- 2.48.1