]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: add option to use named parameter in query arguments
authorDaniel Theophanes <kardianos@gmail.com>
Mon, 3 Oct 2016 16:49:25 +0000 (09:49 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 17 Oct 2016 07:56:35 +0000 (07:56 +0000)
Modify the new Context methods to take a name-value driver struct.
This will require more modifications to drivers to use, but will
reduce the overall number of structures that need to be maintained
over time.

Fixes #12381

Change-Id: I30747533ce418a1be5991a0c8767a26e8451adbd
Reviewed-on: https://go-review.googlesource.com/30166
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/database/sql/convert.go
src/database/sql/ctxutil.go
src/database/sql/driver/driver.go
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index 99aed2398e28b099dae92b78589db468549135ac..cee96319da8548f6d17075127ee40e1fddb4c06c 100644 (file)
@@ -21,8 +21,8 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript
 // Stmt.Query into driver Values.
 //
 // The statement ds may be nil, if no statement is available.
-func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
-       dargs := make([]driver.Value, len(args))
+func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+       nvargs := make([]driver.NamedValue, len(args))
        var si driver.Stmt
        if ds != nil {
                si = ds.si
@@ -33,16 +33,27 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
        if !ok {
                for n, arg := range args {
                        var err error
-                       dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
+                       nvargs[n].Ordinal = n + 1
+                       if np, ok := arg.(NamedParam); ok {
+                               arg = np.Value
+                               nvargs[n].Name = np.Name
+                       }
+                       nvargs[n].Value, err = driver.DefaultParameterConverter.ConvertValue(arg)
+
                        if err != nil {
                                return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
                        }
                }
-               return dargs, nil
+               return nvargs, nil
        }
 
        // Let the Stmt convert its own arguments.
        for n, arg := range args {
+               nvargs[n].Ordinal = n + 1
+               if np, ok := arg.(NamedParam); ok {
+                       arg = np.Value
+                       nvargs[n].Name = np.Name
+               }
                // First, see if the value itself knows how to convert
                // itself to a driver type. For example, a NullString
                // struct changing into a string or nil.
@@ -66,18 +77,18 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
                // same error.
                var err error
                ds.Lock()
-               dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
+               nvargs[n].Value, err = cc.ColumnConverter(n).ConvertValue(arg)
                ds.Unlock()
                if err != nil {
                        return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
                }
-               if !driver.IsValue(dargs[n]) {
+               if !driver.IsValue(nvargs[n].Value) {
                        return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
-                               arg, dargs[n])
+                               arg, nvargs[n].Value)
                }
        }
 
-       return dargs, nil
+       return nvargs, nil
 }
 
 // convertAssign copies to dest the value in src, converting it if possible.
index e1d4c03c9a3f09e69a5c594d7c6385851d3aecff..173f6a9d2bcfa5cb0135f720d8f31df5ab807f7f 100644 (file)
@@ -50,9 +50,13 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver
        }
 }
 
-func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) {
+func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
        if execerCtx, is := execer.(driver.ExecerContext); is {
-               return execerCtx.ExecContext(ctx, query, dargs)
+               return execerCtx.ExecContext(ctx, query, nvdargs)
+       }
+       dargs, err := namedValueToValue(nvdargs)
+       if err != nil {
+               return nil, err
        }
        if ctx.Done() == context.Background().Done() {
                return execer.Exec(query, dargs)
@@ -90,9 +94,13 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, darg
        }
 }
 
-func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) {
+func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
        if queryerCtx, is := queryer.(driver.QueryerContext); is {
-               return queryerCtx.QueryContext(ctx, query, dargs)
+               return queryerCtx.QueryContext(ctx, query, nvdargs)
+       }
+       dargs, err := namedValueToValue(nvdargs)
+       if err != nil {
+               return nil, err
        }
        if ctx.Done() == context.Background().Done() {
                return queryer.Query(query, dargs)
@@ -130,9 +138,13 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, d
        }
 }
 
-func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) {
+func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
        if siCtx, is := si.(driver.StmtExecContext); is {
-               return siCtx.ExecContext(ctx, dargs)
+               return siCtx.ExecContext(ctx, nvdargs)
+       }
+       dargs, err := namedValueToValue(nvdargs)
+       if err != nil {
+               return nil, err
        }
        if ctx.Done() == context.Background().Done() {
                return si.Exec(dargs)
@@ -170,9 +182,13 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value
        }
 }
 
-func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) {
+func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
        if siCtx, is := si.(driver.StmtQueryContext); is {
-               return siCtx.QueryContext(ctx, dargs)
+               return siCtx.QueryContext(ctx, nvdargs)
+       }
+       dargs, err := namedValueToValue(nvdargs)
+       if err != nil {
+               return nil, err
        }
        if ctx.Done() == context.Background().Done() {
                return si.Query(dargs)
@@ -253,3 +269,14 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
                return r.txi, r.err
        }
 }
+
+func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
+       dargs := make([]driver.Value, len(named))
+       for n, param := range named {
+               if len(param.Name) > 0 {
+                       return nil, errors.New("sql: driver does not support the use of Named Parameters")
+               }
+               dargs[n] = param.Value
+       }
+       return dargs, nil
+}
index b3d83f3ff449a0b3f32a323d7ef5e678a8dccdbc..6cc970f688c446a3041f6ba8a5923b8f2cf57f3a 100644 (file)
@@ -24,6 +24,16 @@ import (
 //   time.Time
 type Value interface{}
 
+// NamedValue holds both the value name and value.
+// The Ordinal is the position of the parameter starting from one and is always set.
+// If the Name is not empty it should be used for the parameter identifier and
+// not the ordinal position.
+type NamedValue struct {
+       Name    string
+       Ordinal int
+       Value   Value
+}
+
 // Driver is the interface that must be implemented by a database
 // driver.
 type Driver interface {
@@ -71,7 +81,7 @@ type Execer interface {
 // ExecerContext is like execer, but must honor the context timeout and return
 // when the context is cancelled.
 type ExecerContext interface {
-       ExecContext(ctx context.Context, query string, args []Value) (Result, error)
+       ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
 }
 
 // Queryer is an optional interface that may be implemented by a Conn.
@@ -88,7 +98,7 @@ type Queryer interface {
 // QueryerContext is like Queryer, but most honor the context timeout and return
 // when the context is cancelled.
 type QueryerContext interface {
-       QueryContext(ctx context.Context, query string, args []Value) (Rows, error)
+       QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
 }
 
 // Conn is a connection to a database. It is not used concurrently
@@ -174,13 +184,13 @@ type Stmt interface {
 // StmtExecContext enhances the Stmt interface by providing Exec with context.
 type StmtExecContext interface {
        // ExecContext must honor the context timeout and return when it is cancelled.
-       ExecContext(ctx context.Context, args []Value) (Result, error)
+       ExecContext(ctx context.Context, args []NamedValue) (Result, error)
 }
 
 // StmtQueryContext enhances the Stmt interface by providing Query with context.
 type StmtQueryContext interface {
        // QueryContext must honor the context timeout and return when it is cancelled.
-       QueryContext(ctx context.Context, args []Value) (Rows, error)
+       QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
 }
 
 // ColumnConverter may be optionally implemented by Stmt if the
index aaa13a679964d4242c7fc27b89c14ee92552d647..07f50196a57c751fa4a35121740644ca9b424dec 100644 (file)
@@ -5,6 +5,7 @@
 package sql
 
 import (
+       "context"
        "database/sql/driver"
        "errors"
        "fmt"
@@ -32,6 +33,7 @@ var _ = log.Printf
 //     where types are: "string", [u]int{8,16,32,64}, "bool"
 //   INSERT|<tablename>|col=val,col2=val2,col3=?
 //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
+//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
 //
 // Any of these can be preceded by PANIC|<method>|, to cause the
 // named method on fakeStmt to panic.
@@ -103,6 +105,12 @@ type fakeTx struct {
        c *fakeConn
 }
 
+type boundCol struct {
+       Column      string
+       Placeholder string
+       Ordinal     int
+}
+
 type fakeStmt struct {
        c *fakeConn
        q string // just for debugging
@@ -120,7 +128,7 @@ type fakeStmt struct {
        colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
        placeholders int           // used by INSERT/SELECT: number of ? params
 
-       whereCol []string // used by SELECT (all placeholders)
+       whereCol []boundCol // used by SELECT (all placeholders)
 
        placeholderConverter []driver.ValueConverter // used by INSERT
 }
@@ -339,18 +347,23 @@ func (c *fakeConn) Close() (err error) {
        return nil
 }
 
-func checkSubsetTypes(args []driver.Value) error {
-       for n, arg := range args {
-               switch arg.(type) {
+func checkSubsetTypes(args []driver.NamedValue) error {
+       for _, arg := range args {
+               switch arg.Value.(type) {
                case int64, float64, bool, nil, []byte, string, time.Time:
                default:
-                       return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
+                       return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
                }
        }
        return nil
 }
 
 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+       // Ensure that ExecContext is called if available.
+       panic("ExecContext was not called.")
+}
+
+func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
        // This is an optional interface, but it's implemented here
        // just to check that all the args are of the proper types.
        // ErrSkip is returned so the caller acts as if we didn't
@@ -363,6 +376,11 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error
 }
 
 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+       // Ensure that ExecContext is called if available.
+       panic("QueryContext was not called.")
+}
+
+func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
        // This is an optional interface, but it's implemented here
        // just to check that all the args are of the proper types.
        // ErrSkip is returned so the caller acts as if we didn't
@@ -403,13 +421,13 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err
                        stmt.Close()
                        return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
                }
-               if value != "?" {
+               if !strings.HasPrefix(value, "?") {
                        stmt.Close()
                        return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
                                stmt.table, column)
                }
-               stmt.whereCol = append(stmt.whereCol, column)
                stmt.placeholders++
+               stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
        }
        return stmt, nil
 }
@@ -454,7 +472,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err
                }
                stmt.colName = append(stmt.colName, column)
 
-               if value != "?" {
+               if !strings.HasPrefix(value, "?") {
                        var subsetVal interface{}
                        // Convert to driver subset type
                        switch ctype {
@@ -477,7 +495,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err
                } else {
                        stmt.placeholders++
                        stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
-                       stmt.colValue = append(stmt.colValue, "?")
+                       stmt.colValue = append(stmt.colValue, value)
                }
        }
        return stmt, nil
@@ -580,6 +598,9 @@ var errClosed = errors.New("fakedb: statement has been closed")
 var hookExecBadConn func() bool
 
 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
+       panic("Using ExecContext")
+}
+func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
        if s.panic == "Exec" {
                panic(s.panic)
        }
@@ -620,7 +641,7 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
 // When doInsert is true, add the row to the table.
 // When doInsert is false do prep-work and error checking, but don't
 // actually add the row to the table.
-func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) {
+func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
        db := s.c.db
        if len(args) != s.placeholders {
                panic("error in pkg db; should only get here if size is correct")
@@ -646,8 +667,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
                        return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
                }
                var val interface{}
-               if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
-                       val = args[argPos]
+               if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
+                       if strvalue == "?" {
+                               val = args[argPos].Value
+                       } else {
+                               // Assign value from argument placeholder name.
+                               for _, a := range args {
+                                       if a.Name == strvalue {
+                                               val = a.Value
+                                               break
+                                       }
+                               }
+                       }
                        argPos++
                } else {
                        val = s.colValue[n]
@@ -667,6 +698,10 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
 var hookQueryBadConn func() bool
 
 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
+       panic("Use QueryContext")
+}
+
+func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
        if s.panic == "Query" {
                panic(s.panic)
        }
@@ -700,9 +735,9 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
                }
 
                if s.table == "magicquery" {
-                       if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
-                               if args[0] == "sleep" {
-                                       time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
+                       if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
+                               if args[0].Value == "sleep" {
+                                       time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
                                }
                        }
                }
@@ -725,8 +760,8 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
                        // Process the where clause, skipping non-match rows. This is lazy
                        // and just uses fmt.Sprintf("%v") to test equality. Good enough
                        // for test code.
-                       for widx, wcol := range s.whereCol {
-                               idx := t.columnIndex(wcol)
+                       for _, wcol := range s.whereCol {
+                               idx := t.columnIndex(wcol.Column)
                                if idx == -1 {
                                        t.mu.Unlock()
                                        return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
@@ -736,7 +771,19 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
                                        // lazy hack to avoid sprintf %v on a []byte
                                        tcol = string(bs)
                                }
-                               if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
+                               var argValue interface{}
+                               if wcol.Placeholder == "?" {
+                                       argValue = args[wcol.Ordinal-1].Value
+                               } else {
+                                       // Assign arg value from placeholder name.
+                                       for _, a := range args {
+                                               if a.Name == wcol.Placeholder {
+                                                       argValue = a.Value
+                                                       break
+                                               }
+                                       }
+                               }
+                               if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
                                        continue rows
                                }
                        }
index 970334269ddf8b1540ffc357cac43e338c4a3e10..616acb2be1a5210e71b0960a48675afdf6a5629c 100644 (file)
@@ -67,6 +67,27 @@ func Drivers() []string {
        return list
 }
 
+// NamedParam may be passed into query parameter arguments to associate
+// a named placeholder with a value.
+type NamedParam struct {
+       // Name of the parameter placeholder. If empty the ordinal position in the
+       // argument list will be used.
+       Name string
+
+       // Value of the parameter. It may be assigned the same value types as
+       // the query arguments.
+       Value interface{}
+}
+
+// Param provides a more concise way to create NamedParam values.
+func Param(name string, value interface{}) NamedParam {
+       // This method exists because the go1compat promise
+       // doesn't guarantee that structs don't grow more fields,
+       // so unkeyed struct literals are a vet error. Thus, we don't
+       // want to encourage sql.NamedParam{name, value}.
+       return NamedParam{Name: name, Value: value}
+}
+
 // RawBytes is a byte slice that holds a reference to memory owned by
 // the database itself. After a Scan into a RawBytes, the slice is only
 // valid until the next call to Next, Scan, or Close.
@@ -1064,7 +1085,7 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate
        }()
 
        if execer, ok := dc.ci.(driver.Execer); ok {
-               var dargs []driver.Value
+               var dargs []driver.NamedValue
                dargs, err = driverArgs(nil, args)
                if err != nil {
                        return nil, err
index bce210da97000d1acbd1f26d989de1cfcebec176..885cadf3c66f9c54c8aceae3629791fb833ce2aa 100644 (file)
@@ -395,6 +395,53 @@ func TestMultiResultSetQuery(t *testing.T) {
        }
 }
 
+func TestQueryNamedParam(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+       prepares0 := numPrepares(t, db)
+       rows, err := db.Query(
+               // Ensure the name and age parameters only match on placeholder name, not position.
+               "SELECT|people|age,name|name=?name,age=?age",
+               Param("?age", 2),
+               Param("?name", "Bob"),
+       )
+       if err != nil {
+               t.Fatalf("Query: %v", err)
+       }
+       type row struct {
+               age  int
+               name string
+       }
+       got := []row{}
+       for rows.Next() {
+               var r row
+               err = rows.Scan(&r.age, &r.name)
+               if err != nil {
+                       t.Fatalf("Scan: %v", err)
+               }
+               got = append(got, r)
+       }
+       err = rows.Err()
+       if err != nil {
+               t.Fatalf("Err: %v", err)
+       }
+       want := []row{
+               {age: 2, name: "Bob"},
+       }
+       if !reflect.DeepEqual(got, want) {
+               t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+       }
+
+       // And verify that the final rows.Next() call, which hit EOF,
+       // also closed the rows connection.
+       if n := db.numFreeConns(); n != 1 {
+               t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+       }
+       if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+               t.Errorf("executed %d Prepare statements; want 1", prepares)
+       }
+}
+
 func TestByteOwnership(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)