]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: allow drivers to only implement Context variants
authorDaniel Theophanes <kardianos@gmail.com>
Sun, 24 Sep 2017 02:38:32 +0000 (19:38 -0700)
committerDaniel Theophanes <kardianos@gmail.com>
Tue, 24 Oct 2017 16:51:29 +0000 (16:51 +0000)
Drivers shouldn't need to implement both Queryer and QueryerContext,
they should just implement QueryerContext. Same with Execer and
ExecerContext. This CL tests for QueryContext and ExecerContext
first so drivers do not need to implement Queryer and Execer
with an empty definition.

Fixes #21663

Change-Id: Ifbaa71da669f4bc60f8da8c41a04a4afed699a9f
Reviewed-on: https://go-review.googlesource.com/65733
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/database/sql/ctxutil.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index b73ee86594257927622c5964ac4243482a707a85..170ec7d8a021171d65c0ad439078ab4eb4dcdbab 100644 (file)
@@ -26,8 +26,8 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver
        return si, err
 }
 
-func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
-       if execerCtx, is := execer.(driver.ExecerContext); is {
+func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
+       if execerCtx != nil {
                return execerCtx.ExecContext(ctx, query, nvdargs)
        }
        dargs, err := namedValueToValue(nvdargs)
@@ -43,10 +43,9 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda
        return execer.Exec(query, dargs)
 }
 
-func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
-       if queryerCtx, is := queryer.(driver.QueryerContext); is {
-               ret, err := queryerCtx.QueryContext(ctx, query, nvdargs)
-               return ret, err
+func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
+       if queryerCtx != nil {
+               return queryerCtx.QueryContext(ctx, query, nvdargs)
        }
        dargs, err := namedValueToValue(nvdargs)
        if err != nil {
index 49d352fbf53cec4116b8969ea917ad71c6eb8bed..7c357106882d28dc973143ef59b4d1861656b4f2 100644 (file)
@@ -1276,15 +1276,20 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
        defer func() {
                release(err)
        }()
-       if execer, ok := dc.ci.(driver.Execer); ok {
-               var dargs []driver.NamedValue
-               dargs, err = driverArgs(dc.ci, nil, args)
+       execerCtx, ok := dc.ci.(driver.ExecerContext)
+       var execer driver.Execer
+       if !ok {
+               execer, ok = dc.ci.(driver.Execer)
+       }
+       if ok {
+               var nvdargs []driver.NamedValue
+               nvdargs, err = driverArgs(dc.ci, nil, args)
                if err != nil {
                        return nil, err
                }
                var resi driver.Result
                withLock(dc, func() {
-                       resi, err = ctxDriverExec(ctx, execer, query, dargs)
+                       resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
                })
                if err != driver.ErrSkip {
                        if err != nil {
@@ -1343,15 +1348,20 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat
 // The ctx context is from a query method and the txctx context is from an
 // optional transaction context.
 func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
-       if queryer, ok := dc.ci.(driver.Queryer); ok {
-               dargs, err := driverArgs(dc.ci, nil, args)
+       queryerCtx, ok := dc.ci.(driver.QueryerContext)
+       var queryer driver.Queryer
+       if !ok {
+               queryer, ok = dc.ci.(driver.Queryer)
+       }
+       if ok {
+               nvdargs, err := driverArgs(dc.ci, nil, args)
                if err != nil {
                        releaseConn(err)
                        return nil, err
                }
                var rowsi driver.Rows
                withLock(dc, func() {
-                       rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
+                       rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
                })
                if err != driver.ErrSkip {
                        if err != nil {
index 046d95aff47bb9231756bc6b77c6fec1eaef20e6..3551366369c60c40f6d6bf372f69f565862cfe91 100644 (file)
@@ -3487,6 +3487,127 @@ func TestNamedValueCheckerSkip(t *testing.T) {
        }
 }
 
+type ctxOnlyDriver struct {
+       fakeDriver
+}
+
+func (d *ctxOnlyDriver) Open(dsn string) (driver.Conn, error) {
+       conn, err := d.fakeDriver.Open(dsn)
+       if err != nil {
+               return nil, err
+       }
+       return &ctxOnlyConn{fc: conn.(*fakeConn)}, nil
+}
+
+var (
+       _ driver.Conn           = &ctxOnlyConn{}
+       _ driver.QueryerContext = &ctxOnlyConn{}
+       _ driver.ExecerContext  = &ctxOnlyConn{}
+)
+
+type ctxOnlyConn struct {
+       fc *fakeConn
+
+       queryCtxCalled bool
+       execCtxCalled  bool
+}
+
+func (c *ctxOnlyConn) Begin() (driver.Tx, error) {
+       return c.fc.Begin()
+}
+
+func (c *ctxOnlyConn) Close() error {
+       return c.fc.Close()
+}
+
+// Prepare is still part of the Conn interface, so while it isn't used
+// must be defined for compatibility.
+func (c *ctxOnlyConn) Prepare(q string) (driver.Stmt, error) {
+       panic("not used")
+}
+
+func (c *ctxOnlyConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
+       return c.fc.PrepareContext(ctx, q)
+}
+
+func (c *ctxOnlyConn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
+       c.queryCtxCalled = true
+       return c.fc.QueryContext(ctx, q, args)
+}
+
+func (c *ctxOnlyConn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
+       c.execCtxCalled = true
+       return c.fc.ExecContext(ctx, q, args)
+}
+
+// TestQueryExecContextOnly ensures drivers only need to implement QueryContext
+// and ExecContext methods.
+func TestQueryExecContextOnly(t *testing.T) {
+       // Ensure connection does not implment non-context interfaces.
+       var connType driver.Conn = &ctxOnlyConn{}
+       if _, ok := connType.(driver.Execer); ok {
+               t.Fatalf("%T must not implement driver.Execer", connType)
+       }
+       if _, ok := connType.(driver.Queryer); ok {
+               t.Fatalf("%T must not implement driver.Queryer", connType)
+       }
+
+       Register("ContextOnly", &ctxOnlyDriver{})
+       db, err := Open("ContextOnly", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer db.Close()
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       conn, err := db.Conn(ctx)
+       if err != nil {
+               t.Fatal("db.Conn", err)
+       }
+       defer conn.Close()
+
+       _, err = conn.ExecContext(ctx, "WIPE")
+       if err != nil {
+               t.Fatal("exec wipe", err)
+       }
+
+       _, err = conn.ExecContext(ctx, "CREATE|keys|v1=string")
+       if err != nil {
+               t.Fatal("exec create", err)
+       }
+       expectedValue := "value1"
+       _, err = conn.ExecContext(ctx, "INSERT|keys|v1=?", expectedValue)
+       if err != nil {
+               t.Fatal("exec insert", err)
+       }
+       rows, err := conn.QueryContext(ctx, "SELECT|keys|v1|")
+       if err != nil {
+               t.Fatal("query select", err)
+       }
+       v1 := ""
+       for rows.Next() {
+               err = rows.Scan(&v1)
+               if err != nil {
+                       t.Fatal("rows scan", err)
+               }
+       }
+       rows.Close()
+
+       if v1 != expectedValue {
+               t.Fatalf("expected %q, got %q", expectedValue, v1)
+       }
+
+       coc := conn.dc.ci.(*ctxOnlyConn)
+       if !coc.execCtxCalled {
+               t.Error("ExecContext not called")
+       }
+       if !coc.queryCtxCalled {
+               t.Error("QueryContext not called")
+       }
+}
+
 // badConn implements a bad driver.Conn, for TestBadDriver.
 // The Exec method panics.
 type badConn struct{}