Modify the new Context methods to take a name-value driver struct.
This will require more modifications to drivers to use, but will
reduce the overall number of structures that need to be maintained
over time.
Fixes #12381
Change-Id: I30747533ce418a1be5991a0c8767a26e8451adbd
Reviewed-on: https://go-review.googlesource.com/30166
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
-func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
- dargs := make([]driver.Value, len(args))
+func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+ nvargs := make([]driver.NamedValue, len(args))
var si driver.Stmt
if ds != nil {
si = ds.si
if !ok {
for n, arg := range args {
var err error
- dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
+ nvargs[n].Ordinal = n + 1
+ if np, ok := arg.(NamedParam); ok {
+ arg = np.Value
+ nvargs[n].Name = np.Name
+ }
+ nvargs[n].Value, err = driver.DefaultParameterConverter.ConvertValue(arg)
+
if err != nil {
return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
}
}
- return dargs, nil
+ return nvargs, nil
}
// Let the Stmt convert its own arguments.
for n, arg := range args {
+ nvargs[n].Ordinal = n + 1
+ if np, ok := arg.(NamedParam); ok {
+ arg = np.Value
+ nvargs[n].Name = np.Name
+ }
// First, see if the value itself knows how to convert
// itself to a driver type. For example, a NullString
// struct changing into a string or nil.
// same error.
var err error
ds.Lock()
- dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
+ nvargs[n].Value, err = cc.ColumnConverter(n).ConvertValue(arg)
ds.Unlock()
if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
}
- if !driver.IsValue(dargs[n]) {
+ if !driver.IsValue(nvargs[n].Value) {
return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
- arg, dargs[n])
+ arg, nvargs[n].Value)
}
}
- return dargs, nil
+ return nvargs, nil
}
// convertAssign copies to dest the value in src, converting it if possible.
}
}
-func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) {
+func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
if execerCtx, is := execer.(driver.ExecerContext); is {
- return execerCtx.ExecContext(ctx, query, dargs)
+ return execerCtx.ExecContext(ctx, query, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
}
if ctx.Done() == context.Background().Done() {
return execer.Exec(query, dargs)
}
}
-func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) {
+func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx, is := queryer.(driver.QueryerContext); is {
- return queryerCtx.QueryContext(ctx, query, dargs)
+ return queryerCtx.QueryContext(ctx, query, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
}
if ctx.Done() == context.Background().Done() {
return queryer.Query(query, dargs)
}
}
-func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) {
+func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
if siCtx, is := si.(driver.StmtExecContext); is {
- return siCtx.ExecContext(ctx, dargs)
+ return siCtx.ExecContext(ctx, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
}
if ctx.Done() == context.Background().Done() {
return si.Exec(dargs)
}
}
-func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) {
+func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
if siCtx, is := si.(driver.StmtQueryContext); is {
- return siCtx.QueryContext(ctx, dargs)
+ return siCtx.QueryContext(ctx, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
}
if ctx.Done() == context.Background().Done() {
return si.Query(dargs)
return r.txi, r.err
}
}
+
+func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
+ dargs := make([]driver.Value, len(named))
+ for n, param := range named {
+ if len(param.Name) > 0 {
+ return nil, errors.New("sql: driver does not support the use of Named Parameters")
+ }
+ dargs[n] = param.Value
+ }
+ return dargs, nil
+}
// time.Time
type Value interface{}
+// NamedValue holds both the value name and value.
+// The Ordinal is the position of the parameter starting from one and is always set.
+// If the Name is not empty it should be used for the parameter identifier and
+// not the ordinal position.
+type NamedValue struct {
+ Name string
+ Ordinal int
+ Value Value
+}
+
// Driver is the interface that must be implemented by a database
// driver.
type Driver interface {
// ExecerContext is like execer, but must honor the context timeout and return
// when the context is cancelled.
type ExecerContext interface {
- ExecContext(ctx context.Context, query string, args []Value) (Result, error)
+ ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
}
// Queryer is an optional interface that may be implemented by a Conn.
// QueryerContext is like Queryer, but most honor the context timeout and return
// when the context is cancelled.
type QueryerContext interface {
- QueryContext(ctx context.Context, query string, args []Value) (Rows, error)
+ QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
}
// Conn is a connection to a database. It is not used concurrently
// StmtExecContext enhances the Stmt interface by providing Exec with context.
type StmtExecContext interface {
// ExecContext must honor the context timeout and return when it is cancelled.
- ExecContext(ctx context.Context, args []Value) (Result, error)
+ ExecContext(ctx context.Context, args []NamedValue) (Result, error)
}
// StmtQueryContext enhances the Stmt interface by providing Query with context.
type StmtQueryContext interface {
// QueryContext must honor the context timeout and return when it is cancelled.
- QueryContext(ctx context.Context, args []Value) (Rows, error)
+ QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
}
// ColumnConverter may be optionally implemented by Stmt if the
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
// where types are: "string", [u]int{8,16,32,64}, "bool"
// INSERT|<tablename>|col=val,col2=val2,col3=?
// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
+// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
c *fakeConn
}
+type boundCol struct {
+ Column string
+ Placeholder string
+ Ordinal int
+}
+
type fakeStmt struct {
c *fakeConn
q string // just for debugging
colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
placeholders int // used by INSERT/SELECT: number of ? params
- whereCol []string // used by SELECT (all placeholders)
+ whereCol []boundCol // used by SELECT (all placeholders)
placeholderConverter []driver.ValueConverter // used by INSERT
}
return nil
}
-func checkSubsetTypes(args []driver.Value) error {
- for n, arg := range args {
- switch arg.(type) {
+func checkSubsetTypes(args []driver.NamedValue) error {
+ for _, arg := range args {
+ switch arg.Value.(type) {
case int64, float64, bool, nil, []byte, string, time.Time:
default:
- return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
+ return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
}
}
return nil
}
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ // Ensure that ExecContext is called if available.
+ panic("ExecContext was not called.")
+}
+
+func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
}
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+ // Ensure that ExecContext is called if available.
+ panic("QueryContext was not called.")
+}
+
+func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
stmt.Close()
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
}
- if value != "?" {
+ if !strings.HasPrefix(value, "?") {
stmt.Close()
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
stmt.table, column)
}
- stmt.whereCol = append(stmt.whereCol, column)
stmt.placeholders++
+ stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
}
return stmt, nil
}
}
stmt.colName = append(stmt.colName, column)
- if value != "?" {
+ if !strings.HasPrefix(value, "?") {
var subsetVal interface{}
// Convert to driver subset type
switch ctype {
} else {
stmt.placeholders++
stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
- stmt.colValue = append(stmt.colValue, "?")
+ stmt.colValue = append(stmt.colValue, value)
}
}
return stmt, nil
var hookExecBadConn func() bool
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
+ panic("Using ExecContext")
+}
+func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if s.panic == "Exec" {
panic(s.panic)
}
// When doInsert is true, add the row to the table.
// When doInsert is false do prep-work and error checking, but don't
// actually add the row to the table.
-func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) {
+func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
db := s.c.db
if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct")
return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
}
var val interface{}
- if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
- val = args[argPos]
+ if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
+ if strvalue == "?" {
+ val = args[argPos].Value
+ } else {
+ // Assign value from argument placeholder name.
+ for _, a := range args {
+ if a.Name == strvalue {
+ val = a.Value
+ break
+ }
+ }
+ }
argPos++
} else {
val = s.colValue[n]
var hookQueryBadConn func() bool
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
+ panic("Use QueryContext")
+}
+
+func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if s.panic == "Query" {
panic(s.panic)
}
}
if s.table == "magicquery" {
- if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
- if args[0] == "sleep" {
- time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
+ if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
+ if args[0].Value == "sleep" {
+ time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
}
}
}
// Process the where clause, skipping non-match rows. This is lazy
// and just uses fmt.Sprintf("%v") to test equality. Good enough
// for test code.
- for widx, wcol := range s.whereCol {
- idx := t.columnIndex(wcol)
+ for _, wcol := range s.whereCol {
+ idx := t.columnIndex(wcol.Column)
if idx == -1 {
t.mu.Unlock()
return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
// lazy hack to avoid sprintf %v on a []byte
tcol = string(bs)
}
- if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
+ var argValue interface{}
+ if wcol.Placeholder == "?" {
+ argValue = args[wcol.Ordinal-1].Value
+ } else {
+ // Assign arg value from placeholder name.
+ for _, a := range args {
+ if a.Name == wcol.Placeholder {
+ argValue = a.Value
+ break
+ }
+ }
+ }
+ if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
continue rows
}
}
return list
}
+// NamedParam may be passed into query parameter arguments to associate
+// a named placeholder with a value.
+type NamedParam struct {
+ // Name of the parameter placeholder. If empty the ordinal position in the
+ // argument list will be used.
+ Name string
+
+ // Value of the parameter. It may be assigned the same value types as
+ // the query arguments.
+ Value interface{}
+}
+
+// Param provides a more concise way to create NamedParam values.
+func Param(name string, value interface{}) NamedParam {
+ // This method exists because the go1compat promise
+ // doesn't guarantee that structs don't grow more fields,
+ // so unkeyed struct literals are a vet error. Thus, we don't
+ // want to encourage sql.NamedParam{name, value}.
+ return NamedParam{Name: name, Value: value}
+}
+
// RawBytes is a byte slice that holds a reference to memory owned by
// the database itself. After a Scan into a RawBytes, the slice is only
// valid until the next call to Next, Scan, or Close.
}()
if execer, ok := dc.ci.(driver.Execer); ok {
- var dargs []driver.Value
+ var dargs []driver.NamedValue
dargs, err = driverArgs(nil, args)
if err != nil {
return nil, err
}
}
+func TestQueryNamedParam(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query(
+ // Ensure the name and age parameters only match on placeholder name, not position.
+ "SELECT|people|age,name|name=?name,age=?age",
+ Param("?age", 2),
+ Param("?name", "Bob"),
+ )
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ for rows.Next() {
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got = append(got, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want := []row{
+ {age: 2, name: "Bob"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)