]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: use driver.ColumnConverter everywhere consistently
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 29 May 2012 18:09:09 +0000 (11:09 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 29 May 2012 18:09:09 +0000 (11:09 -0700)
It was only being used for (*Stmt).Exec, not Query, and not for
the same two methods on *DB.

This unifies (*Stmt).Exec's old inline code into the old
subsetArgs function, renaming it in the process (changing the
old word "subset" to "driver", mostly converted earlier)

Fixes #3640

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6258045

src/pkg/database/sql/convert.go
src/pkg/database/sql/fakedb_test.go
src/pkg/database/sql/sql.go
src/pkg/database/sql/sql_test.go

index bfcb03ccf8d42ac627f6d207d0e0f773f78353e2..964dc1848504830aaa33ff3a55a25a5a2802c5c8 100644 (file)
@@ -14,19 +14,61 @@ import (
        "strconv"
 )
 
-// subsetTypeArgs takes a slice of arguments from callers of the sql
-// package and converts them into a slice of the driver package's
-// "subset types".
-func subsetTypeArgs(args []interface{}) ([]driver.Value, error) {
-       out := make([]driver.Value, len(args))
+// driverArgs converts arguments from callers of Stmt.Exec and
+// Stmt.Query into driver Values.
+//
+// The statement si may be nil, if no statement is available.
+func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
+       dargs := make([]driver.Value, len(args))
+       cc, ok := si.(driver.ColumnConverter)
+
+       // Normal path, for a driver.Stmt that is not a ColumnConverter.
+       if !ok {
+               for n, arg := range args {
+                       var err error
+                       dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
+                       if err != nil {
+                               return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
+                       }
+               }
+               return dargs, nil
+       }
+
+       // Let the Stmt convert its own arguments.
        for n, arg := range args {
+               // First, see if the value itself knows how to convert
+               // itself to a driver type.  For example, a NullString
+               // struct changing into a string or nil.
+               if svi, ok := arg.(driver.Valuer); ok {
+                       sv, err := svi.Value()
+                       if err != nil {
+                               return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err)
+                       }
+                       if !driver.IsValue(sv) {
+                               return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from Value", n, sv)
+                       }
+                       arg = sv
+               }
+
+               // Second, ask the column to sanity check itself. For
+               // example, drivers might use this to make sure that
+               // an int64 values being inserted into a 16-bit
+               // integer field is in range (before getting
+               // truncated), or that a nil can't go into a NOT NULL
+               // column before going across the network to get the
+               // same error.
                var err error
-               out[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
+               dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
                if err != nil {
-                       return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err)
+                       return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
+               }
+               if !driver.IsValue(dargs[n]) {
+                       return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
+                               arg, dargs[n])
                }
        }
-       return out, nil
+
+       return dargs, nil
 }
 
 // convertAssign copies to dest the value in src, converting it if possible.
index 184e7756c51e4228e89e420393519bff05361a81..833e8bf4f514836cbd0d42bdb2594d839edb55f0 100644 (file)
@@ -383,6 +383,9 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
 }
 
 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
+       if len(s.placeholderConverter) == 0 {
+               return driver.DefaultParameterConverter
+       }
        return s.placeholderConverter[idx]
 }
 
@@ -598,6 +601,28 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
        return nil
 }
 
+// fakeDriverString is like driver.String, but indirects pointers like
+// DefaultValueConverter.
+//
+// This could be surprising behavior to retroactively apply to
+// driver.String now that Go1 is out, but this is convenient for
+// our TestPointerParamsAndScans.
+//
+type fakeDriverString struct{}
+
+func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
+       switch c := v.(type) {
+       case string, []byte:
+               return v, nil
+       case *string:
+               if c == nil {
+                       return nil, nil
+               }
+               return *c, nil
+       }
+       return fmt.Sprintf("%v", v), nil
+}
+
 func converterForType(typ string) driver.ValueConverter {
        switch typ {
        case "bool":
@@ -607,9 +632,9 @@ func converterForType(typ string) driver.ValueConverter {
        case "int32":
                return driver.Int32
        case "string":
-               return driver.NotNull{Converter: driver.String}
+               return driver.NotNull{Converter: fakeDriverString{}}
        case "nullstring":
-               return driver.Null{Converter: driver.String}
+               return driver.Null{Converter: fakeDriverString{}}
        case "int64":
                // TODO(coopernurse): add type-specific converter
                return driver.NotNull{Converter: driver.DefaultParameterConverter}
index 89136ef6e40f1cc473ddabbbc0a91976e5b9b153..b87f57f92f3c614bdcfce2f76115fb8047f9fc38 100644 (file)
@@ -326,13 +326,10 @@ func (db *DB) prepare(query string) (stmt *Stmt, err error) {
 
 // Exec executes a query without returning any rows.
 func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
-       sargs, err := subsetTypeArgs(args)
-       if err != nil {
-               return nil, err
-       }
        var res Result
+       var err error
        for i := 0; i < 10; i++ {
-               res, err = db.exec(query, sargs)
+               res, err = db.exec(query, args)
                if err != driver.ErrBadConn {
                        break
                }
@@ -340,7 +337,7 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
        return res, err
 }
 
-func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) {
+func (db *DB) exec(query string, args []interface{}) (res Result, err error) {
        ci, err := db.conn()
        if err != nil {
                return nil, err
@@ -348,7 +345,11 @@ func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) {
        defer db.putConn(ci, err)
 
        if execer, ok := ci.(driver.Execer); ok {
-               resi, err := execer.Exec(query, sargs)
+               dargs, err := driverArgs(nil, args)
+               if err != nil {
+                       return nil, err
+               }
+               resi, err := execer.Exec(query, dargs)
                if err != driver.ErrSkip {
                        if err != nil {
                                return nil, err
@@ -363,7 +364,12 @@ func (db *DB) exec(query string, sargs []driver.Value) (res Result, err error) {
        }
        defer sti.Close()
 
-       resi, err := sti.Exec(sargs)
+       dargs, err := driverArgs(sti, args)
+       if err != nil {
+               return nil, err
+       }
+
+       resi, err := sti.Exec(dargs)
        if err != nil {
                return nil, err
        }
@@ -577,13 +583,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
        }
        defer tx.releaseConn()
 
-       sargs, err := subsetTypeArgs(args)
-       if err != nil {
-               return nil, err
-       }
-
        if execer, ok := ci.(driver.Execer); ok {
-               resi, err := execer.Exec(query, sargs)
+               dargs, err := driverArgs(nil, args)
+               if err != nil {
+                       return nil, err
+               }
+               resi, err := execer.Exec(query, dargs)
                if err == nil {
                        return result{resi}, nil
                }
@@ -598,7 +603,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
        }
        defer sti.Close()
 
-       resi, err := sti.Exec(sargs)
+       dargs, err := driverArgs(sti, args)
+       if err != nil {
+               return nil, err
+       }
+
+       resi, err := sti.Exec(dargs)
        if err != nil {
                return nil, err
        }
@@ -674,51 +684,12 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
                return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
        }
 
-       sargs := make([]driver.Value, len(args))
-
-       // Convert args to subset types.
-       if cc, ok := si.(driver.ColumnConverter); ok {
-               for n, arg := range args {
-                       // First, see if the value itself knows how to convert
-                       // itself to a driver type.  For example, a NullString
-                       // struct changing into a string or nil.
-                       if svi, ok := arg.(driver.Valuer); ok {
-                               sv, err := svi.Value()
-                               if err != nil {
-                                       return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err)
-                               }
-                               if !driver.IsValue(sv) {
-                                       return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from Value", n, sv)
-                               }
-                               arg = sv
-                       }
-
-                       // Second, ask the column to sanity check itself. For
-                       // example, drivers might use this to make sure that
-                       // an int64 values being inserted into a 16-bit
-                       // integer field is in range (before getting
-                       // truncated), or that a nil can't go into a NOT NULL
-                       // column before going across the network to get the
-                       // same error.
-                       sargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
-                       if err != nil {
-                               return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
-                       }
-                       if !driver.IsValue(sargs[n]) {
-                               return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
-                                       arg, sargs[n])
-                       }
-               }
-       } else {
-               for n, arg := range args {
-                       sargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
-                       if err != nil {
-                               return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
-                       }
-               }
+       dargs, err := driverArgs(si, args)
+       if err != nil {
+               return nil, err
        }
 
-       resi, err := si.Exec(sargs)
+       resi, err := si.Exec(dargs)
        if err != nil {
                return nil, err
        }
@@ -805,11 +776,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
        if want := si.NumInput(); want != -1 && len(args) != want {
                return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args))
        }
-       sargs, err := subsetTypeArgs(args)
+
+       dargs, err := driverArgs(si, args)
        if err != nil {
                return nil, err
        }
-       rowsi, err := si.Query(sargs)
+
+       rowsi, err := si.Query(dargs)
        if err != nil {
                releaseConn(err)
                return nil, err
index b296705865f10c935efe541ae771a41de0beb8a1..1bfb59020b4380eda25deadda398deda4090f1c7 100644 (file)
@@ -306,8 +306,8 @@ func TestExec(t *testing.T) {
                {[]interface{}{7, 9}, ""},
 
                // Invalid conversions:
-               {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting Exec argument #1's type: sql/driver: value 4294967295 overflows int32"},
-               {[]interface{}{"Brad", "strconv fail"}, "sql: converting Exec argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"},
+               {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument #1's type: sql/driver: value 4294967295 overflows int32"},
+               {[]interface{}{"Brad", "strconv fail"}, "sql: converting argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"},
 
                // Wrong number of args:
                {[]interface{}{}, "sql: expected 2 arguments, got 0"},