]> Cypherpunks repositories - gostls13.git/commitdiff
reflect: implement Set(nil), SetValue(nil) for PtrValue and MapValue
authorRuss Cox <rsc@golang.org>
Wed, 21 Apr 2010 00:02:08 +0000 (17:02 -0700)
committerRuss Cox <rsc@golang.org>
Wed, 21 Apr 2010 00:02:08 +0000 (17:02 -0700)
R=r
CC=golang-dev
https://golang.org/cl/823048

src/pkg/reflect/all_test.go
src/pkg/reflect/value.go

index 67bfe9eaf0b2e989a5302d0f613e7964452c64f1..552b09d89a6d89faf6b0845bafde43f725fa3bef 100644 (file)
@@ -350,6 +350,26 @@ func TestPtrPointTo(t *testing.T) {
        }
 }
 
+func TestPtrSetNil(t *testing.T) {
+       var i int32 = 1234
+       ip := &i
+       vip := NewValue(&ip)
+       vip.(*PtrValue).Elem().(*PtrValue).Set(nil)
+       if ip != nil {
+               t.Errorf("got non-nil (%d), want nil", *ip)
+       }
+}
+
+func TestMapSetNil(t *testing.T) {
+       m := make(map[string]int)
+       vm := NewValue(&m)
+       vm.(*PtrValue).Elem().(*MapValue).Set(nil)
+       if m != nil {
+               t.Errorf("got non-nil (%p), want nil", m)
+       }
+}
+
+
 func TestAll(t *testing.T) {
        testType(t, 1, Typeof((int8)(0)), "int8")
        testType(t, 2, Typeof((*int8)(nil)).(*PtrType).Elem(), "int8")
@@ -838,6 +858,12 @@ func TestMap(t *testing.T) {
        if ok {
                t.Errorf("newm[\"a\"] = %d after delete", v)
        }
+
+       mv = NewValue(&m).(*PtrValue).Elem().(*MapValue)
+       mv.Set(nil)
+       if m != nil {
+               t.Errorf("mv.Set(nil) failed")
+       }
 }
 
 func TestChan(t *testing.T) {
index d8ddb289a4075351b7af660616bf8f54169fdc58..7730fefc38dcb787c8f2fe07fa59daefe31e131a 100644 (file)
@@ -1038,12 +1038,22 @@ func (v *MapValue) Set(x *MapValue) {
        if !v.canSet {
                panic(cannotSet)
        }
+       if x == nil {
+               *(**uintptr)(v.addr) = nil
+               return
+       }
        typesMustMatch(v.typ, x.typ)
        *(*uintptr)(v.addr) = *(*uintptr)(x.addr)
 }
 
 // Set sets v to the value x.
-func (v *MapValue) SetValue(x Value) { v.Set(x.(*MapValue)) }
+func (v *MapValue) SetValue(x Value) {
+       if x == nil {
+               v.Set(nil)
+               return
+       }
+       v.Set(x.(*MapValue))
+}
 
 // Get returns the uintptr value of v.
 // It is mainly useful for printing.
@@ -1146,6 +1156,10 @@ func (v *PtrValue) Get() uintptr { return *(*uintptr)(v.addr) }
 // Set assigns x to v.
 // The new value x must have the same type as v.
 func (v *PtrValue) Set(x *PtrValue) {
+       if x == nil {
+               *(**uintptr)(v.addr) = nil
+               return
+       }
        if !v.canSet {
                panic(cannotSet)
        }
@@ -1156,7 +1170,13 @@ func (v *PtrValue) Set(x *PtrValue) {
 }
 
 // Set sets v to the value x.
-func (v *PtrValue) SetValue(x Value) { v.Set(x.(*PtrValue)) }
+func (v *PtrValue) SetValue(x Value) {
+       if x == nil {
+               v.Set(nil)
+               return
+       }
+       v.Set(x.(*PtrValue))
+}
 
 // PointTo changes v to point to x.
 func (v *PtrValue) PointTo(x Value) {