]> Cypherpunks repositories - gostls13.git/commitdiff
changes for more restricted reflect.SetValue
authorRuss Cox <rsc@golang.org>
Mon, 18 Apr 2011 18:36:22 +0000 (14:36 -0400)
committerRuss Cox <rsc@golang.org>
Mon, 18 Apr 2011 18:36:22 +0000 (14:36 -0400)
R=golang-dev, r
CC=golang-dev
https://golang.org/cl/4423043

14 files changed:
src/cmd/gofmt/rewrite.go
src/pkg/asn1/asn1_test.go
src/pkg/gob/codec_test.go
src/pkg/gob/decode.go
src/pkg/gob/decoder.go
src/pkg/gob/encode.go
src/pkg/gob/encoder_test.go
src/pkg/json/decode.go
src/pkg/json/decode_test.go
src/pkg/netchan/export.go
src/pkg/netchan/import.go
src/pkg/rpc/server.go
src/pkg/testing/quick/quick.go
src/pkg/xml/read_test.go

index 47d1ac46ce0b4da9b1b1ce9401ca9e786bd7fda0..614296d6ab8bb085070968bf0c7f858dcd94b8b5 100644 (file)
@@ -276,21 +276,21 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value)
                return v
 
        case reflect.Struct:
-               v := reflect.Zero(p.Type())
+               v := reflect.New(p.Type()).Elem()
                for i := 0; i < p.NumField(); i++ {
                        v.Field(i).Set(subst(m, p.Field(i), pos))
                }
                return v
 
        case reflect.Ptr:
-               v := reflect.Zero(p.Type())
+               v := reflect.New(p.Type()).Elem()
                if elem := p.Elem(); elem.IsValid() {
                        v.Set(subst(m, elem, pos).Addr())
                }
                return v
 
        case reflect.Interface:
-               v := reflect.Zero(p.Type())
+               v := reflect.New(p.Type()).Elem()
                if elem := p.Elem(); elem.IsValid() {
                        v.Set(subst(m, elem, pos))
                }
index 018c534eb8f4911653c6196ab62056437025046b..d2a35b75ef71c54941d2290dd995fad2fe894674 100644 (file)
@@ -267,11 +267,6 @@ func TestParseFieldParameters(t *testing.T) {
        }
 }
 
-type unmarshalTest struct {
-       in  []byte
-       out interface{}
-}
-
 type TestObjectIdentifierStruct struct {
        OID ObjectIdentifier
 }
@@ -290,7 +285,10 @@ type TestElementsAfterString struct {
        A, B int
 }
 
-var unmarshalTestData []unmarshalTest = []unmarshalTest{
+var unmarshalTestData = []struct {
+       in  []byte
+       out interface{}
+}{
        {[]byte{0x02, 0x01, 0x42}, newInt(0x42)},
        {[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}},
        {[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}},
@@ -309,9 +307,7 @@ var unmarshalTestData []unmarshalTest = []unmarshalTest{
 
 func TestUnmarshal(t *testing.T) {
        for i, test := range unmarshalTestData {
-               pv := reflect.Zero(reflect.NewValue(test.out).Type())
-               zv := reflect.Zero(pv.Type().Elem())
-               pv.Set(zv.Addr())
+               pv := reflect.New(reflect.Typeof(test.out).Elem())
                val := pv.Interface()
                _, err := Unmarshal(test.in, val)
                if err != nil {
index 28042ccaa3a2db8a815eaa9ac079797fd6207eca..fc157da5f6dd1786a8c8543c18f57a584506a13d 100644 (file)
@@ -999,7 +999,6 @@ type Bad0 struct {
        C  float64
 }
 
-
 func TestInvalidField(t *testing.T) {
        var bad0 Bad0
        bad0.CH = make(chan int)
index aebe75e345dac4310dd726241d1fb7c92e2a6792..b6d7cbea81b3c6e20d2186be8ea6d4c2566f17d8 100644 (file)
@@ -581,7 +581,7 @@ func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, p uintpt
 // unlike the other items we can't use a pointer directly.
 func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
        instr := &decInstr{op, 0, indir, 0, ovfl}
-       up := unsafe.Pointer(v.UnsafeAddr())
+       up := unsafe.Pointer(unsafeAddr(v))
        if indir > 1 {
                up = decIndirect(up, indir)
        }
@@ -608,8 +608,8 @@ func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, p uintptr,
        v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p)))
        n := int(state.decodeUint())
        for i := 0; i < n; i++ {
-               key := decodeIntoValue(state, keyOp, keyIndir, reflect.Zero(mtyp.Key()), ovfl)
-               elem := decodeIntoValue(state, elemOp, elemIndir, reflect.Zero(mtyp.Elem()), ovfl)
+               key := decodeIntoValue(state, keyOp, keyIndir, allocValue(mtyp.Key()), ovfl)
+               elem := decodeIntoValue(state, elemOp, elemIndir, allocValue(mtyp.Elem()), ovfl)
                v.SetMapIndex(key, elem)
        }
 }
@@ -686,8 +686,8 @@ func setInterfaceValue(ivalue reflect.Value, value reflect.Value) {
 // Interfaces are encoded as the name of a concrete type followed by a value.
 // If the name is empty, the value is nil and no value is sent.
 func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p uintptr, indir int) {
-       // Create an interface reflect.Value.  We need one even for the nil case.
-       ivalue := reflect.Zero(ityp)
+       // Create a writable interface reflect.Value.  We need one even for the nil case.
+       ivalue := allocValue(ityp)
        // Read the name of the concrete type.
        b := make([]byte, state.decodeUint())
        state.b.Read(b)
@@ -712,7 +712,7 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui
        // in case we want to ignore the value by skipping it completely).
        state.decodeUint()
        // Read the concrete value.
-       value := reflect.Zero(typ)
+       value := allocValue(typ)
        dec.decodeValue(concreteId, value)
        if dec.err != nil {
                error(dec.err)
@@ -1209,9 +1209,9 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) {
                        name := base.Name()
                        errorf("gob: type mismatch: no fields matched compiling decoder for %s", name)
                }
-               dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), ut.indir)
+               dec.decodeStruct(engine, ut, uintptr(unsafeAddr(val)), ut.indir)
        } else {
-               dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr()))
+               dec.decodeSingle(engine, ut, uintptr(unsafeAddr(val)))
        }
 }
 
@@ -1256,3 +1256,26 @@ func init() {
        }
        decOpTable[reflect.Uintptr] = uop
 }
+
+// Gob assumes it can call UnsafeAddr on any Value
+// in order to get a pointer it can copy data from.
+// Values that have just been created and do not point
+// into existing structs or slices cannot be addressed,
+// so simulate it by returning a pointer to a copy.
+// Each call allocates once.
+func unsafeAddr(v reflect.Value) uintptr {
+       if v.CanAddr() {
+               return v.UnsafeAddr()
+       }
+       x := reflect.New(v.Type()).Elem()
+       x.Set(v)
+       return x.UnsafeAddr()
+}
+
+// Gob depends on being able to take the address
+// of zeroed Values it creates, so use this wrapper instead
+// of the standard reflect.Zero.
+// Each call allocates once.
+func allocValue(t reflect.Type) reflect.Value {
+       return reflect.New(t).Elem()
+}
index a631c27a2be710221b59cbd52ff7a3a270488791..f77bcd5afcd2f092647b3f8d6ec4607b9bd4a10d 100644 (file)
@@ -171,12 +171,18 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
        return dec.DecodeValue(value)
 }
 
-// DecodeValue reads the next value from the connection and stores
-// it in the data represented by the reflection value.
-// The value must be the correct type for the next
-// data item received, or it may be nil, which means the
-// value will be discarded.
-func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
+// DecodeValue reads the next value from the connection.
+// If v is the zero reflect.Value (v.Kind() == Invalid), DecodeValue discards the value.
+// Otherwise, it stores the value into v.  In that case, v must represent
+// a non-nil pointer to data or be an assignable reflect.Value (v.CanSet())
+func (dec *Decoder) DecodeValue(v reflect.Value) os.Error {
+       if v.IsValid() {
+               if v.Kind() == reflect.Ptr && !v.IsNil() {
+                       // That's okay, we'll store through the pointer.
+               } else if !v.CanSet() {
+                       return os.ErrorString("gob: DecodeValue of unassignable value")
+               }
+       }
        // Make sure we're single-threaded through here.
        dec.mutex.Lock()
        defer dec.mutex.Unlock()
@@ -185,7 +191,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
        dec.err = nil
        id := dec.decodeTypeSequence(false)
        if dec.err == nil {
-               dec.decodeValue(id, value)
+               dec.decodeValue(id, v)
        }
        return dec.err
 }
index 36bde08aa819200dd5951399b554f0e2fa056c17..2652fd221cb2c4ac30fcbd42f91cafe402e14214 100644 (file)
@@ -402,7 +402,7 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in
        if !v.IsValid() {
                errorf("gob: encodeReflectValue: nil element")
        }
-       op(nil, state, unsafe.Pointer(v.UnsafeAddr()))
+       op(nil, state, unsafe.Pointer(unsafeAddr(v)))
 }
 
 // encodeMap encodes a map as unsigned count followed by key:value pairs.
@@ -695,8 +695,8 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf
                value = reflect.Indirect(value)
        }
        if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct {
-               enc.encodeStruct(b, engine, value.UnsafeAddr())
+               enc.encodeStruct(b, engine, unsafeAddr(value))
        } else {
-               enc.encodeSingle(b, engine, value.UnsafeAddr())
+               enc.encodeSingle(b, engine, unsafeAddr(value))
        }
 }
index 3d5dfdb86ec13a95c2828ac968714c9fda8e4442..7b02a0b42ebfe883e92fff4c848c086c6c93a010 100644 (file)
@@ -170,7 +170,7 @@ func TestTypeToPtrType(t *testing.T) {
                A int
        }
        t0 := Type0{7}
-       t0p := (*Type0)(nil)
+       t0p := new(Type0)
        if err := encAndDec(t0, t0p); err != nil {
                t.Error(err)
        }
index a5fd33912e61086b25269d5fed97b78028327695..19896edc43446bc2eb60a43f208676019656f53b 100644 (file)
@@ -124,8 +124,7 @@ func (d *decodeState) unmarshal(v interface{}) (err os.Error) {
 
        rv := reflect.NewValue(v)
        pv := rv
-       if pv.Kind() != reflect.Ptr ||
-               pv.IsNil() {
+       if pv.Kind() != reflect.Ptr || pv.IsNil() {
                return &InvalidUnmarshalError{reflect.Typeof(v)}
        }
 
@@ -267,17 +266,17 @@ func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, refl
                        v = iv.Elem()
                        continue
                }
+
                pv := v
                if pv.Kind() != reflect.Ptr {
                        break
                }
 
-               if pv.Elem().Kind() != reflect.Ptr &&
-                       wantptr && !isUnmarshaler {
+               if pv.Elem().Kind() != reflect.Ptr && wantptr && pv.CanSet() && !isUnmarshaler {
                        return nil, pv
                }
                if pv.IsNil() {
-                       pv.Set(reflect.Zero(pv.Type().Elem()).Addr())
+                       pv.Set(reflect.New(pv.Type().Elem()))
                }
                if isUnmarshaler {
                        // Using v.Interface().(Unmarshaler)
@@ -443,6 +442,8 @@ func (d *decodeState) object(v reflect.Value) {
                return
        }
 
+       var mapElem reflect.Value
+
        for {
                // Read opening " of string key or closing }.
                op := d.scanWhile(scanSkipSpace)
@@ -466,7 +467,13 @@ func (d *decodeState) object(v reflect.Value) {
                // Figure out field corresponding to key.
                var subv reflect.Value
                if mv.IsValid() {
-                       subv = reflect.Zero(mv.Type().Elem())
+                       elemType := mv.Type().Elem()
+                       if !mapElem.IsValid() {
+                               mapElem = reflect.New(elemType).Elem()
+                       } else {
+                               mapElem.Set(reflect.Zero(elemType))
+                       }
+                       subv = mapElem
                } else {
                        var f reflect.StructField
                        var ok bool
index cf8f53bc4a281240dd8eb6c62f25fb2c5ea3c208..9e0f0b0d594ce130436aa7f659f4444c86e52f36 100644 (file)
@@ -138,8 +138,7 @@ func TestUnmarshal(t *testing.T) {
                        continue
                }
                // v = new(right-type)
-               v := reflect.NewValue(tt.ptr)
-               v.Set(reflect.Zero(v.Type().Elem()).Addr())
+               v := reflect.New(reflect.Typeof(tt.ptr).Elem())
                if err := Unmarshal([]byte(in), v.Interface()); !reflect.DeepEqual(err, tt.err) {
                        t.Errorf("#%d: %v want %v", i, err, tt.err)
                        continue
index 2209f04e8a4e160fb90cd7111d423080c11bca5f..dacee4f183356f44eaefd69b11a21c10a3c8974d 100644 (file)
@@ -221,7 +221,7 @@ func (client *expClient) serveSend(hdr header) {
                return
        }
        // Create a new value for each received item.
-       val := reflect.Zero(nch.ch.Type().Elem())
+       val := reflect.New(nch.ch.Type().Elem()).Elem()
        if err := client.decode(val); err != nil {
                expLog("value decode:", err, "; type ", nch.ch.Type())
                return
index 9921486bdf4b5cd5e6a8ccdbe1899e46284959f3..2f5ce58f845e7ee32767a307da8c3c9c64d5d726 100644 (file)
@@ -133,7 +133,7 @@ func (imp *Importer) run() {
                ackHdr.SeqNum = hdr.SeqNum
                imp.encode(ackHdr, payAck, nil)
                // Create a new value for each received item.
-               value := reflect.Zero(nch.ch.Type().Elem())
+               value := reflect.New(nch.ch.Type().Elem()).Elem()
                if e := imp.decode(value); e != nil {
                        impLog("importer value decode:", e)
                        return
index af31a65cc9c8b0c3426b06e11b21fc06b4d80371..d46bc4343e066f4abf070f025f5c52e61b763f9e 100644 (file)
@@ -297,12 +297,6 @@ type InvalidRequest struct{}
 
 var invalidRequest = InvalidRequest{}
 
-func _new(t reflect.Type) reflect.Value {
-       v := reflect.Zero(t)
-       v.Set(reflect.Zero(t.Elem()).Addr())
-       return v
-}
-
 func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
        resp := server.getResponse()
        // Encode the response header
@@ -411,8 +405,8 @@ func (server *Server) ServeCodec(codec ServerCodec) {
                }
 
                // Decode the argument value.
-               argv := _new(mtype.ArgType)
-               replyv := _new(mtype.ReplyType)
+               argv := reflect.New(mtype.ArgType.Elem())
+               replyv := reflect.New(mtype.ReplyType.Elem())
                err = codec.ReadRequestBody(argv.Interface())
                if err != nil {
                        if err == os.EOF || err == io.ErrUnexpectedEOF {
index 52fd38d9c8056e9543549d39278caa4606fd49c3..94450da988d3af68900440054b4ac6fb67b1a4c6 100644 (file)
@@ -107,8 +107,8 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
                if !ok {
                        return reflect.Value{}, false
                }
-               p := reflect.Zero(concrete)
-               p.Set(v.Addr())
+               p := reflect.New(concrete.Elem())
+               p.Elem().Set(v)
                return p, true
        case reflect.Slice:
                numElems := rand.Intn(complexSize)
@@ -129,7 +129,7 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
                }
                return reflect.NewValue(string(codePoints)), true
        case reflect.Struct:
-               s := reflect.Zero(t)
+               s := reflect.New(t).Elem()
                for i := 0; i < s.NumField(); i++ {
                        v, ok := Value(concrete.Field(i).Type, rand)
                        if !ok {
index 0e28e73a636ee9677ef9b005cf9f60fc988c6b67..3d1e5b8844a528b30e5cba4fdfacc687ca733d8b 100644 (file)
@@ -288,9 +288,7 @@ var pathTests = []interface{}{
 
 func TestUnmarshalPaths(t *testing.T) {
        for _, pt := range pathTests {
-               p := reflect.Zero(reflect.NewValue(pt).Type())
-               p.Set(reflect.Zero(p.Type().Elem()).Addr())
-               v := p.Interface()
+               v := reflect.New(reflect.Typeof(pt).Elem()).Interface()
                if err := Unmarshal(StringReader(pathTestString), v); err != nil {
                        t.Fatalf("Unmarshal: %s", err)
                }