]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: support scanning into user defined string types
authorDaniel Theophanes <kardianos@gmail.com>
Thu, 30 Mar 2017 23:03:03 +0000 (16:03 -0700)
committerDaniel Theophanes <kardianos@gmail.com>
Fri, 31 Mar 2017 05:02:02 +0000 (05:02 +0000)
User defined numeric types such as "type Int int64" have
been able to be scanned into without a custom scanner by
using the reflect scan code path used to convert between
various numeric types. Add in a path for string types
for symmetry and least surprise.

Fixes #18101

Change-Id: I00553bcf021ffe6d95047eca0067ee94b54ff501
Reviewed-on: https://go-review.googlesource.com/39031
Run-TryBot: Daniel Theophanes <kardianos@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/database/sql/convert.go
src/database/sql/convert_test.go
src/database/sql/sql_test.go

index ea2f377810eac95f479066cb70024f8bc28f25dc..630a585ab223dc6de4f718977b4d905cf56100ac 100644 (file)
@@ -270,6 +270,11 @@ func convertAssign(dest, src interface{}) error {
                return nil
        }
 
+       // The following conversions use a string value as an intermediate representation
+       // to convert between various numeric types.
+       //
+       // This also allows scanning into user defined types such as "type Int int64".
+       // For symmetry, also check for string destination types.
        switch dv.Kind() {
        case reflect.Ptr:
                if src == nil {
@@ -306,6 +311,15 @@ func convertAssign(dest, src interface{}) error {
                }
                dv.SetFloat(f64)
                return nil
+       case reflect.String:
+               switch v := src.(type) {
+               case string:
+                       dv.SetString(v)
+                       return nil
+               case []byte:
+                       dv.SetString(string(v))
+                       return nil
+               }
        }
 
        return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
index 4dfab1f6bef6a8d400ed92634e0c8900354bc078..853a12ce959f38d0953948f7202e7c8df1d84063 100644 (file)
@@ -17,9 +17,11 @@ import (
 var someTime = time.Unix(123, 0)
 var answer int64 = 42
 
-type userDefined float64
-
-type userDefinedSlice []int
+type (
+       userDefined       float64
+       userDefinedSlice  []int
+       userDefinedString string
+)
 
 type conversionTest struct {
        s, d interface{} // source and destination
@@ -39,6 +41,7 @@ type conversionTest struct {
        wantptr    *int64 // if non-nil, *d's pointed value must be equal to *wantptr
        wantnil    bool   // if true, *d must be *int64(nil)
        wantusrdef userDefined
+       wantusrstr userDefinedString
 }
 
 // Target variables for scanning into.
@@ -171,6 +174,7 @@ var conversionTests = []conversionTest{
        {s: int64(123), d: new(userDefined), wantusrdef: 123},
        {s: "1.5", d: new(userDefined), wantusrdef: 1.5},
        {s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},
+       {s: "str", d: new(userDefinedString), wantusrstr: "str"},
 
        // Other errors
        {s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
@@ -260,6 +264,9 @@ func TestConversions(t *testing.T) {
                if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
                        errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
                }
+               if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) {
+                       errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString))
+               }
        }
 }
 
index 4c1adf51b61ed988beb8a62fff611828967d2d1e..f511aa4ac3553c02d4ce1ad1e0b3dfa8c5a68618 100644 (file)
@@ -3155,6 +3155,24 @@ func TestPing(t *testing.T) {
        }
 }
 
+// Issue 18101.
+func TestTypedString(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+
+       type Str string
+       var scanned Str
+
+       err := db.QueryRow("SELECT|people|name|name=?", "Alice").Scan(&scanned)
+       if err != nil {
+               t.Fatal(err)
+       }
+       expected := Str("Alice")
+       if scanned != expected {
+               t.Errorf("expected %+v, got %+v", expected, scanned)
+       }
+}
+
 func BenchmarkConcurrentDBExec(b *testing.B) {
        b.ReportAllocs()
        ct := new(concurrentDBExecTest)