]> Cypherpunks repositories - gostls13.git/commitdiff
gob: add support for complex numbers
authorRob Pike <r@golang.org>
Thu, 24 Jun 2010 22:07:28 +0000 (15:07 -0700)
committerRob Pike <r@golang.org>
Thu, 24 Jun 2010 22:07:28 +0000 (15:07 -0700)
R=rsc
CC=golang-dev
https://golang.org/cl/1708048

src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/encode.go
src/pkg/gob/type.go

index dad0ac48c169909a0d59aafab3cd5c0929a6d292..d8bdf2d2f467e491fa7792ef7f559fdee509fcd7 100644 (file)
@@ -112,7 +112,9 @@ var boolResult = []byte{0x07, 0x01}
 var signedResult = []byte{0x07, 2 * 17}
 var unsignedResult = []byte{0x07, 17}
 var floatResult = []byte{0x07, 0xFE, 0x31, 0x40}
-// The result of encoding "hello" with field number 6
+// The result of encoding a number 17+19i with field number 7
+var complexResult = []byte{0x07, 0xFE, 0x31, 0x40, 0xFE, 0x33, 0x40}
+// The result of encoding "hello" with field number 7
 var bytesResult = []byte{0x07, 0x05, 'h', 'e', 'l', 'l', 'o'}
 
 func newencoderState(b *bytes.Buffer) *encoderState {
@@ -537,6 +539,45 @@ func TestScalarDecInstructions(t *testing.T) {
                }
        }
 
+       // complex
+       {
+               var data struct {
+                       a complex
+               }
+               instr := &decInstr{decOpMap[reflect.Complex], 6, 0, 0, ovfl}
+               state := newDecodeStateFromData(complexResult)
+               execDec("complex", instr, state, t, unsafe.Pointer(&data))
+               if data.a != 17+19i {
+                       t.Errorf("complex a = %v not 17+19i", data.a)
+               }
+       }
+
+       // complex64
+       {
+               var data struct {
+                       a complex64
+               }
+               instr := &decInstr{decOpMap[reflect.Complex64], 6, 0, 0, ovfl}
+               state := newDecodeStateFromData(complexResult)
+               execDec("complex", instr, state, t, unsafe.Pointer(&data))
+               if data.a != 17+19i {
+                       t.Errorf("complex a = %v not 17+19i", data.a)
+               }
+       }
+
+       // complex128
+       {
+               var data struct {
+                       a complex128
+               }
+               instr := &decInstr{decOpMap[reflect.Complex128], 6, 0, 0, ovfl}
+               state := newDecodeStateFromData(complexResult)
+               execDec("complex", instr, state, t, unsafe.Pointer(&data))
+               if data.a != 17+19i {
+                       t.Errorf("complex a = %v not 17+19i", data.a)
+               }
+       }
+
        // bytes == []uint8
        {
                var data struct {
@@ -576,6 +617,7 @@ func TestEndToEnd(t *testing.T) {
                n       *[3]float
                strs    *[2]string
                int64s  *[]int64
+               ri      complex64
                s       string
                y       []byte
                t       *T2
@@ -590,6 +632,7 @@ func TestEndToEnd(t *testing.T) {
                n:      &[3]float{1.5, 2.5, 3.5},
                strs:   &[2]string{s1, s2},
                int64s: &[]int64{77, 89, 123412342134},
+               ri:     17 - 23i,
                s:      "Now is the time",
                y:      []byte("hello, sailor"),
                t:      &T2{"this is T2"},
@@ -616,6 +659,8 @@ func TestOverflow(t *testing.T) {
                maxu uint64
                maxf float64
                minf float64
+               maxc complex128
+               minc complex128
        }
        var it inputT
        var err os.Error
@@ -758,6 +803,22 @@ func TestOverflow(t *testing.T) {
        if err == nil || err.String() != `value for "maxf" out of range` {
                t.Error("wrong overflow error for float32:", err)
        }
+
+       // complex64
+       b.Reset()
+       it = inputT{
+               maxc: cmplx(math.MaxFloat32*2, math.MaxFloat32*2),
+       }
+       type outc64 struct {
+               maxc complex64
+               minc complex64
+       }
+       var o8 outc64
+       enc.Encode(it)
+       err = dec.Decode(&o8)
+       if err == nil || err.String() != `value for "maxc" out of range` {
+               t.Error("wrong overflow error for complex64:", err)
+       }
 }
 
 
index 8f5c383ea925b23243ab328bcb3e83d8166c499c..0dbf81488770ff9a774d2304cde296e3f47167b0 100644 (file)
@@ -151,6 +151,11 @@ func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) {
        decodeUint(state)
 }
 
+func ignoreTwoUints(i *decInstr, state *decodeState, p unsafe.Pointer) {
+       decodeUint(state)
+       decodeUint(state)
+}
+
 func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
@@ -286,25 +291,30 @@ func floatFromBits(u uint64) float64 {
        return math.Float64frombits(v)
 }
 
-func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
-       if i.indir > 0 {
-               if *(*unsafe.Pointer)(p) == nil {
-                       *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32))
-               }
-               p = *(*unsafe.Pointer)(p)
-       }
+func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
        v := floatFromBits(decodeUint(state))
        av := v
        if av < 0 {
                av = -av
        }
-       if math.MaxFloat32 < av { // underflow is OK
+       // +Inf is OK in both 32- and 64-bit floats.  Underflow is always OK.
+       if math.MaxFloat32 < av && av <= math.MaxFloat64 {
                state.err = i.ovfl
        } else {
                *(*float32)(p) = float32(v)
        }
 }
 
+func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+       if i.indir > 0 {
+               if *(*unsafe.Pointer)(p) == nil {
+                       *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32))
+               }
+               p = *(*unsafe.Pointer)(p)
+       }
+       storeFloat32(i, state, p)
+}
+
 func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) {
        if i.indir > 0 {
                if *(*unsafe.Pointer)(p) == nil {
@@ -315,6 +325,30 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) {
        *(*float64)(p) = floatFromBits(uint64(decodeUint(state)))
 }
 
+// Complex numbers are just a pair of floating-point numbers, real part first.
+func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+       if i.indir > 0 {
+               if *(*unsafe.Pointer)(p) == nil {
+                       *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex64))
+               }
+               p = *(*unsafe.Pointer)(p)
+       }
+       storeFloat32(i, state, p)
+       storeFloat32(i, state, unsafe.Pointer(uintptr(p)+uintptr(unsafe.Sizeof(float(0)))))
+}
+
+func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) {
+       if i.indir > 0 {
+               if *(*unsafe.Pointer)(p) == nil {
+                       *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex128))
+               }
+               p = *(*unsafe.Pointer)(p)
+       }
+       real := floatFromBits(uint64(decodeUint(state)))
+       imag := floatFromBits(uint64(decodeUint(state)))
+       *(*complex128)(p) = cmplx(real, imag)
+}
+
 // uint8 arrays are encoded as an unsigned count followed by the raw bytes.
 func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
        if i.indir > 0 {
@@ -540,21 +574,25 @@ func ignoreSlice(state *decodeState, elemOp decOp) os.Error {
        return ignoreArrayHelper(state, elemOp, int(decodeUint(state)))
 }
 
+// Index by Go types.
 var decOpMap = []decOp{
-       reflect.Bool:    decBool,
-       reflect.Int8:    decInt8,
-       reflect.Int16:   decInt16,
-       reflect.Int32:   decInt32,
-       reflect.Int64:   decInt64,
-       reflect.Uint8:   decUint8,
-       reflect.Uint16:  decUint16,
-       reflect.Uint32:  decUint32,
-       reflect.Uint64:  decUint64,
-       reflect.Float32: decFloat32,
-       reflect.Float64: decFloat64,
-       reflect.String:  decString,
-}
-
+       reflect.Bool:       decBool,
+       reflect.Int8:       decInt8,
+       reflect.Int16:      decInt16,
+       reflect.Int32:      decInt32,
+       reflect.Int64:      decInt64,
+       reflect.Uint8:      decUint8,
+       reflect.Uint16:     decUint16,
+       reflect.Uint32:     decUint32,
+       reflect.Uint64:     decUint64,
+       reflect.Float32:    decFloat32,
+       reflect.Float64:    decFloat64,
+       reflect.Complex64:  decComplex64,
+       reflect.Complex128: decComplex128,
+       reflect.String:     decString,
+}
+
+// Indexed by gob types.  tComplex will be added during type.init().
 var decIgnoreOpMap = map[typeId]decOp{
        tBool:   ignoreUint,
        tInt:    ignoreUint,
@@ -652,6 +690,8 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
                // Special cases
                wire := dec.wireType[wireId]
                switch {
+               case wire == nil:
+                       panic("internal error: can't find ignore op for type " + wireId.string())
                case wire.arrayT != nil:
                        elemId := wire.arrayT.Elem
                        elemOp, err := dec.decIgnoreOpFor(elemId)
@@ -728,6 +768,8 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
                return fw == tUint
        case *reflect.FloatType:
                return fw == tFloat
+       case *reflect.ComplexType:
+               return fw == tComplex
        case *reflect.StringType:
                return fw == tString
        case *reflect.ArrayType:
@@ -866,39 +908,39 @@ func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
 }
 
 func init() {
-       // We assume that the size of float is sufficient to tell us whether it is
-       // equivalent to float32 or to float64.   This is very unlikely to be wrong.
-       var op decOp
-       switch unsafe.Sizeof(float(0)) {
-       case unsafe.Sizeof(float32(0)):
-               op = decFloat32
-       case unsafe.Sizeof(float64(0)):
-               op = decFloat64
+       var fop, cop decOp
+       switch reflect.Typeof(float(0)).Bits() {
+       case 32:
+               fop = decFloat32
+               cop = decComplex64
+       case 64:
+               fop = decFloat64
+               cop = decComplex128
        default:
                panic("gob: unknown size of float")
        }
-       decOpMap[reflect.Float] = op
+       decOpMap[reflect.Float] = fop
+       decOpMap[reflect.Complex] = cop
 
-       // A similar assumption about int and uint.  Also assume int and uint have the same size.
-       var uop decOp
-       switch unsafe.Sizeof(int(0)) {
-       case unsafe.Sizeof(int32(0)):
-               op = decInt32
+       var iop, uop decOp
+       switch reflect.Typeof(int(0)).Bits() {
+       case 32:
+               iop = decInt32
                uop = decUint32
-       case unsafe.Sizeof(int64(0)):
-               op = decInt64
+       case 64:
+               iop = decInt64
                uop = decUint64
        default:
                panic("gob: unknown size of int/uint")
        }
-       decOpMap[reflect.Int] = op
+       decOpMap[reflect.Int] = iop
        decOpMap[reflect.Uint] = uop
 
        // Finally uintptr
-       switch unsafe.Sizeof(uintptr(0)) {
-       case unsafe.Sizeof(uint32(0)):
+       switch reflect.Typeof(uintptr(0)).Bits() {
+       case 32:
                uop = decUint32
-       case unsafe.Sizeof(uint64(0)):
+       case 64:
                uop = decUint64
        default:
                panic("gob: unknown size of uintptr")
index 76032389e3ba8daf003c7fa473b0cc255ab652ed..b48c1f698a9a7cd172c54a0d50ec5296f26ae6b8 100644 (file)
@@ -467,7 +467,7 @@ func floatBits(f float64) uint64 {
 }
 
 func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) {
-       f := float(*(*float)(p))
+       f := *(*float)(p)
        if f != 0 || state.inArray {
                v := floatBits(float64(f))
                state.update(i)
@@ -476,7 +476,7 @@ func encFloat(i *encInstr, state *encoderState, p unsafe.Pointer) {
 }
 
 func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
-       f := float32(*(*float32)(p))
+       f := *(*float32)(p)
        if f != 0 || state.inArray {
                v := floatBits(float64(f))
                state.update(i)
@@ -493,6 +493,40 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
        }
 }
 
+// Complex numbers are just a pair of floating-point numbers, real part first.
+func encComplex(i *encInstr, state *encoderState, p unsafe.Pointer) {
+       c := *(*complex)(p)
+       if c != 0+0i || state.inArray {
+               rpart := floatBits(float64(real(c)))
+               ipart := floatBits(float64(imag(c)))
+               state.update(i)
+               encodeUint(state, rpart)
+               encodeUint(state, ipart)
+       }
+}
+
+func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
+       c := *(*complex64)(p)
+       if c != 0+0i || state.inArray {
+               rpart := floatBits(float64(real(c)))
+               ipart := floatBits(float64(imag(c)))
+               state.update(i)
+               encodeUint(state, rpart)
+               encodeUint(state, ipart)
+       }
+}
+
+func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
+       c := *(*complex128)(p)
+       if c != 0+0i || state.inArray {
+               rpart := floatBits(real(c))
+               ipart := floatBits(imag(c))
+               state.update(i)
+               encodeUint(state, rpart)
+               encodeUint(state, ipart)
+       }
+}
+
 // Byte arrays are encoded as an unsigned count followed by the raw bytes.
 func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
        b := *(*[]byte)(p)
@@ -602,22 +636,25 @@ func encodeMap(b *bytes.Buffer, rt reflect.Type, p uintptr, keyOp, elemOp encOp,
 }
 
 var encOpMap = []encOp{
-       reflect.Bool:    encBool,
-       reflect.Int:     encInt,
-       reflect.Int8:    encInt8,
-       reflect.Int16:   encInt16,
-       reflect.Int32:   encInt32,
-       reflect.Int64:   encInt64,
-       reflect.Uint:    encUint,
-       reflect.Uint8:   encUint8,
-       reflect.Uint16:  encUint16,
-       reflect.Uint32:  encUint32,
-       reflect.Uint64:  encUint64,
-       reflect.Uintptr: encUintptr,
-       reflect.Float:   encFloat,
-       reflect.Float32: encFloat32,
-       reflect.Float64: encFloat64,
-       reflect.String:  encString,
+       reflect.Bool:       encBool,
+       reflect.Int:        encInt,
+       reflect.Int8:       encInt8,
+       reflect.Int16:      encInt16,
+       reflect.Int32:      encInt32,
+       reflect.Int64:      encInt64,
+       reflect.Uint:       encUint,
+       reflect.Uint8:      encUint8,
+       reflect.Uint16:     encUint16,
+       reflect.Uint32:     encUint32,
+       reflect.Uint64:     encUint64,
+       reflect.Uintptr:    encUintptr,
+       reflect.Float:      encFloat,
+       reflect.Float32:    encFloat32,
+       reflect.Float64:    encFloat64,
+       reflect.Complex:    encComplex,
+       reflect.Complex64:  encComplex64,
+       reflect.Complex128: encComplex128,
+       reflect.String:     encString,
 }
 
 // Return the encoding op for the base type under rt and
@@ -688,7 +725,7 @@ func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
                }
        }
        if op == nil {
-               return op, indir, os.ErrorString("gob enc: can't happen: encode type" + rt.String())
+               return op, indir, os.ErrorString("gob enc: can't happen: encode type " + rt.String())
        }
        return op, indir, nil
 }
index 6a3e6ba65891e423a963f123f70dbcfaf74fd222..2ad36ae653875dba327275d880e6f139726176a5 100644 (file)
@@ -78,12 +78,17 @@ func (t *commonType) Name() string { return t.name }
 // Create and check predefined types
 // The string for tBytes is "bytes" not "[]byte" to signify its specialness.
 
-var tBool = bootstrapType("bool", false, 1)
-var tInt = bootstrapType("int", int(0), 2)
-var tUint = bootstrapType("uint", uint(0), 3)
-var tFloat = bootstrapType("float", float64(0), 4)
-var tBytes = bootstrapType("bytes", make([]byte, 0), 5)
-var tString = bootstrapType("string", "", 6)
+var (
+       // Primordial types, needed during initialization.
+       tBool   = bootstrapType("bool", false, 1)
+       tInt    = bootstrapType("int", int(0), 2)
+       tUint   = bootstrapType("uint", uint(0), 3)
+       tFloat  = bootstrapType("float", float64(0), 4)
+       tBytes  = bootstrapType("bytes", make([]byte, 0), 5)
+       tString = bootstrapType("string", "", 6)
+       // Types added to the language later, not needed during initialization.
+       tComplex typeId
+)
 
 // Predefined because it's needed by the Decoder
 var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id
@@ -94,10 +99,17 @@ func init() {
        checkId(9, mustGetTypeInfo(reflect.Typeof(commonType{})).id)
        checkId(11, mustGetTypeInfo(reflect.Typeof(structType{})).id)
        checkId(12, mustGetTypeInfo(reflect.Typeof(fieldType{})).id)
+
+       // Complex was added after gob was written, so appears after the
+       // fundamental types are built.
+       tComplex = bootstrapType("complex", 0+0i, 15)
+       decIgnoreOpMap[tComplex] = ignoreTwoUints
+
        builtinIdToType = make(map[typeId]gobType)
        for k, v := range idToType {
                builtinIdToType[k] = v
        }
+
        // Move the id space upwards to allow for growth in the predefined world
        // without breaking existing files.
        if nextId > firstUserId {
@@ -241,6 +253,9 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
        case *reflect.FloatType:
                return tFloat.gobType(), nil
 
+       case *reflect.ComplexType:
+               return tComplex.gobType(), nil
+
        case *reflect.StringType:
                return tString.gobType(), nil
 
@@ -335,7 +350,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
        rt := reflect.Typeof(e)
        _, present := types[rt]
        if present {
-               panic("bootstrap type already present: " + name)
+               panic("bootstrap type already present: " + name + ", " + rt.String())
        }
        typ := &commonType{name: name}
        types[rt] = typ