]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: ensure all driver interfaces are called under single lock
authorDaniel Theophanes <kardianos@gmail.com>
Tue, 17 Oct 2017 22:59:56 +0000 (15:59 -0700)
committerDaniel Theophanes <kardianos@gmail.com>
Wed, 25 Oct 2017 17:21:58 +0000 (17:21 +0000)
Russ pointed out in a previous CL golang.org/cl/65731 that not only
was the locking incomplete, previous changes did not correctly
lock driver calls in other sections. After inspecting
driverConn, driverStmt, driverResult, Tx, and Rows structs
where driver interfaces are stored, I discovered a few more places
that failed to lock driver calls. The largest of these
was the parameter type converter "driverArgs".

driverArgs was typically called right before another call to the
driver in a locked region, so I made the entire driverArgs expect
a locked driver mutex and combined the region. This should not
be a problem because the connection is pulled out of the connection
pool either way so there shouldn't be contention.

Fixes #21117

Change-Id: I88d46f74dca25fb11a30f0bf8e79785a73133d23
Reviewed-on: https://go-review.googlesource.com/71433
Run-TryBot: Daniel Theophanes <kardianos@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
src/database/sql/convert.go
src/database/sql/convert_test.go
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index c349a96edf0d8dbcfa893ed317856f74a357cc72..b44bed559d5a597b0644d5e3dc8de87512cd846a 100644 (file)
@@ -12,7 +12,6 @@ import (
        "fmt"
        "reflect"
        "strconv"
-       "sync"
        "time"
        "unicode"
        "unicode/utf8"
@@ -38,17 +37,10 @@ func validateNamedValueName(name string) error {
        return fmt.Errorf("name %q does not begin with a letter", name)
 }
 
-func driverNumInput(ds *driverStmt) int {
-       ds.Lock()
-       defer ds.Unlock() // in case NumInput panics
-       return ds.si.NumInput()
-}
-
 // ccChecker wraps the driver.ColumnConverter and allows it to be used
 // as if it were a NamedValueChecker. If the driver ColumnConverter
 // is not present then the NamedValueChecker will return driver.ErrSkip.
 type ccChecker struct {
-       sync.Locker
        cci  driver.ColumnConverter
        want int
 }
@@ -88,9 +80,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
        // same error.
        var err error
        arg := nv.Value
-       c.Lock()
        nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
-       c.Unlock()
        if err != nil {
                return err
        }
@@ -112,7 +102,7 @@ func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
 // Stmt.Query into driver Values.
 //
 // The statement ds may be nil, if no statement is available.
-func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
        nvargs := make([]driver.NamedValue, len(args))
 
        // -1 means the driver doesn't know how to count the number of
@@ -124,8 +114,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na
        var cc ccChecker
        if ds != nil {
                si = ds.si
-               want = driverNumInput(ds)
-               cc.Locker = ds.Locker
+               want = ds.si.NumInput()
                cc.want = want
        }
 
index 35dbab3339ebdf28a23e16f0f8cdf01f6206dcf3..c198177760ac32f1c6941269dfed76495fa482bd 100644 (file)
@@ -481,7 +481,7 @@ func TestDriverArgs(t *testing.T) {
        }
        for i, tt := range tests {
                ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
-               got, err := driverArgs(nil, ds, tt.args)
+               got, err := driverArgsConnLocked(nil, ds, tt.args)
                if err != nil {
                        t.Errorf("test[%d]: %v", i, err)
                        continue
index 070b7834539d52f02f7b070695871c9177acf763..31e22a7a74343ad60868f136cc5d62955b969811 100644 (file)
@@ -1005,6 +1005,7 @@ type rowsCursor struct {
 }
 
 func (rc *rowsCursor) touchMem() {
+       rc.parentMem.touchMem()
        rc.line++
 }
 
index c17b2b543b57202e89fa7c7b3474651ad7f5ee1b..be73b5e372920f0be92d2a0f0cd890faedfa8cd7 100644 (file)
@@ -1368,12 +1368,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
        }
        if ok {
                var nvdargs []driver.NamedValue
-               nvdargs, err = driverArgs(dc.ci, nil, args)
-               if err != nil {
-                       return nil, err
-               }
                var resi driver.Result
                withLock(dc, func() {
+                       nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+                       if err != nil {
+                               return
+                       }
                        resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
                })
                if err != driver.ErrSkip {
@@ -1439,13 +1439,14 @@ func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn fu
                queryer, ok = dc.ci.(driver.Queryer)
        }
        if ok {
-               nvdargs, err := driverArgs(dc.ci, nil, args)
-               if err != nil {
-                       releaseConn(err)
-                       return nil, err
-               }
+               var nvdargs []driver.NamedValue
                var rowsi driver.Rows
+               var err error
                withLock(dc, func() {
+                       nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+                       if err != nil {
+                               return
+                       }
                        rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
                })
                if err != driver.ErrSkip {
@@ -2034,11 +2035,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
                stmt.mu.Unlock()
 
                if si == nil {
-                       cs, err := stmt.prepareOnConnLocked(ctx, dc)
+                       withLock(dc, func() {
+                               var ds *driverStmt
+                               ds, err = stmt.prepareOnConnLocked(ctx, dc)
+                               si = ds.si
+                       })
                        if err != nil {
                                return &Stmt{stickyErr: err}
                        }
-                       si = cs.si
                }
                parentStmt = stmt
        }
@@ -2230,14 +2234,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
 }
 
 func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
-       dargs, err := driverArgs(ci, ds, args)
+       ds.Lock()
+       defer ds.Unlock()
+
+       dargs, err := driverArgsConnLocked(ci, ds, args)
        if err != nil {
                return nil, err
        }
 
-       ds.Lock()
-       defer ds.Unlock()
-
        resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
        if err != nil {
                return nil, err
@@ -2401,10 +2405,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
 }
 
 func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
-       var want int
-       withLock(ds, func() {
-               want = ds.si.NumInput()
-       })
+       ds.Lock()
+       defer ds.Unlock()
+
+       want := ds.si.NumInput()
 
        // -1 means the driver doesn't know how to count the number of
        // placeholders, so we won't sanity check input here and instead let the
@@ -2413,14 +2417,11 @@ func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, arg
                return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
        }
 
-       dargs, err := driverArgs(ci, ds, args)
+       dargs, err := driverArgsConnLocked(ci, ds, args)
        if err != nil {
                return nil, err
        }
 
-       ds.Lock()
-       defer ds.Unlock()
-
        rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
        if err != nil {
                return nil, err
@@ -2583,9 +2584,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
        if rs.closed {
                return false, false
        }
+
+       // Lock the driver connection before calling the driver interface
+       // rowsi to prevent a Tx from rolling back the connection at the same time.
+       rs.dc.Lock()
+       defer rs.dc.Unlock()
+
        if rs.lastcols == nil {
                rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
        }
+
        rs.lasterr = rs.rowsi.Next(rs.lastcols)
        if rs.lasterr != nil {
                // Close the connection if there is a driver error.
@@ -2635,6 +2643,12 @@ func (rs *Rows) NextResultSet() bool {
                doClose = true
                return false
        }
+
+       // Lock the driver connection before calling the driver interface
+       // rowsi to prevent a Tx from rolling back the connection at the same time.
+       rs.dc.Lock()
+       defer rs.dc.Unlock()
+
        rs.lasterr = nextResultSet.NextResultSet()
        if rs.lasterr != nil {
                doClose = true
@@ -2666,6 +2680,9 @@ func (rs *Rows) Columns() ([]string, error) {
        if rs.rowsi == nil {
                return nil, errors.New("sql: no Rows available")
        }
+       rs.dc.Lock()
+       defer rs.dc.Unlock()
+
        return rs.rowsi.Columns(), nil
 }
 
@@ -2680,7 +2697,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
        if rs.rowsi == nil {
                return nil, errors.New("sql: no Rows available")
        }
-       return rowsColumnInfoSetup(rs.rowsi), nil
+       rs.dc.Lock()
+       defer rs.dc.Unlock()
+
+       return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
 }
 
 // ColumnType contains the name and type of a column.
@@ -2741,7 +2761,7 @@ func (ci *ColumnType) DatabaseTypeName() string {
        return ci.databaseType
 }
 
-func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
+func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
        names := rowsi.Columns()
 
        list := make([]*ColumnType, len(names))
index dead273503f5ad64ce292579f403e77ea95f03de..f7b7d988e1b80b9e0c579e235c9945c405ab3cd2 100644 (file)
@@ -3157,6 +3157,9 @@ func TestIssue6081(t *testing.T) {
 // In the test, a context is canceled while the query is in process so
 // the internal rollback will run concurrently with the explicitly called
 // Tx.Rollback.
+//
+// The addition of calling rows.Next also tests
+// Issue 21117.
 func TestIssue18429(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)
@@ -3189,6 +3192,9 @@ func TestIssue18429(t *testing.T) {
                        // reported.
                        rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
                        if rows != nil {
+                               // Call Next to test Issue 21117 and check for races.
+                               for rows.Next() {
+                               }
                                rows.Close()
                        }
                        // This call will race with the context cancel rollback to complete