]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: treat pointers as nullable types like encoding/json
authorAndrew Pritchard <awpritchard@gmail.com>
Wed, 8 Feb 2012 06:14:15 +0000 (17:14 +1100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 8 Feb 2012 06:14:15 +0000 (17:14 +1100)
- convert from nil pointers to the nil interface{}
- dereference non-nil pointers
- convert from nil interface{}s to nil pointers
- allocate pointers for non-nil interface{}s
- tests for all of the above

R=golang-dev, bradfitz, rsc, rogpeppe
CC=golang-dev
https://golang.org/cl/5630052

src/pkg/database/sql/convert.go
src/pkg/database/sql/convert_test.go
src/pkg/database/sql/driver/types.go
src/pkg/database/sql/driver/types_test.go

index 4924ac14e468c0102beb7abcc436b86e5b09648e..31ff47f721da1905c439b8ed346e11e566b219cd 100644 (file)
@@ -110,6 +110,14 @@ func convertAssign(dest, src interface{}) error {
        }
 
        switch dv.Kind() {
+       case reflect.Ptr:
+               if src == nil {
+                       dv.Set(reflect.Zero(dv.Type()))
+                       return nil
+               } else {
+                       dv.Set(reflect.New(dv.Type().Elem()))
+                       return convertAssign(dv.Interface(), src)
+               }
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
                s := asString(src)
                i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
index 34ee93987fc2cbd75748398097703aed8847f403..9c362d7320a17bf7c3e759abb92bfa0b143c5756 100644 (file)
@@ -13,6 +13,7 @@ import (
 )
 
 var someTime = time.Unix(123, 0)
+var answer int64 = 42
 
 type conversionTest struct {
        s, d interface{} // source and destination
@@ -27,6 +28,8 @@ type conversionTest struct {
        wantbool  bool // used if d is of type *bool
        wanterr   string
        wantiface interface{}
+       wantptr   *int64 // if non-nil, *d's pointed value must be equal to *wantptr
+       wantnil   bool   // if true, *d must be *int64(nil)
 }
 
 // Target variables for scanning into.
@@ -42,6 +45,7 @@ var (
        scanf32    float32
        scanf64    float64
        scantime   time.Time
+       scanptr    *int64
        scaniface  interface{}
 )
 
@@ -98,6 +102,10 @@ var conversionTests = []conversionTest{
        {s: "1.5", d: &scanf32, wantf32: float32(1.5)},
        {s: "1.5", d: &scanf64, wantf64: float64(1.5)},
 
+       // Pointers
+       {s: interface{}(nil), d: &scanptr, wantnil: true},
+       {s: int64(42), d: &scanptr, wantptr: &answer},
+
        // To interface{}
        {s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
        {s: int64(1), d: &scaniface, wantiface: int64(1)},
@@ -107,6 +115,10 @@ var conversionTests = []conversionTest{
        {s: nil, d: &scaniface},
 }
 
+func intPtrValue(intptr interface{}) interface{} {
+       return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int()
+}
+
 func intValue(intptr interface{}) int64 {
        return reflect.Indirect(reflect.ValueOf(intptr)).Int()
 }
@@ -162,6 +174,16 @@ func TestConversions(t *testing.T) {
                if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
                        errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
                }
+               if ct.wantnil && *ct.d.(**int64) != nil {
+                       errf("want nil, got %v", intPtrValue(ct.d))
+               }
+               if ct.wantptr != nil {
+                       if *ct.d.(**int64) == nil {
+                               errf("want pointer to %v, got nil", *ct.wantptr)
+                       } else if *ct.wantptr != intPtrValue(ct.d) {
+                               errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
+                       }
+               }
                if ifptr, ok := ct.d.(*interface{}); ok {
                        if !reflect.DeepEqual(ct.wantiface, scaniface) {
                                errf("want interface %#v, got %#v", ct.wantiface, scaniface)
index f38388523119b1731f33e398d76c4cfac26ffc97..ce3c943ead27170b7105f260e97cd0cf407b1647 100644 (file)
@@ -248,6 +248,13 @@ func (defaultConverter) ConvertValue(v interface{}) (interface{}, error) {
 
        rv := reflect.ValueOf(v)
        switch rv.Kind() {
+       case reflect.Ptr:
+               // indirect pointers
+               if rv.IsNil() {
+                       return nil, nil
+               } else {
+                       return defaultConverter{}.ConvertValue(rv.Elem().Interface())
+               }
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
                return rv.Int(), nil
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
index 966bc6b45877eaf303e4bc5ec40aaff477637221..ab82bca7166e3ce7b466c317e91c6f015bcf7723 100644 (file)
@@ -18,6 +18,7 @@ type valueConverterTest struct {
 }
 
 var now = time.Now()
+var answer int64 = 42
 
 var valueConverterTests = []valueConverterTest{
        {Bool, "true", true, ""},
@@ -37,6 +38,9 @@ var valueConverterTests = []valueConverterTest{
        {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"},
        {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"},
        {DefaultParameterConverter, now, now, ""},
+       {DefaultParameterConverter, (*int64)(nil), nil, ""},
+       {DefaultParameterConverter, &answer, answer, ""},
+       {DefaultParameterConverter, &now, now, ""},
 }
 
 func TestValueConverters(t *testing.T) {