]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: correct level of write to same var for race detector
authorDaniel Theophanes <kardianos@gmail.com>
Mon, 12 Jun 2017 18:28:18 +0000 (11:28 -0700)
committerDaniel Theophanes <kardianos@gmail.com>
Mon, 12 Jun 2017 21:58:53 +0000 (21:58 +0000)
Rather then write to the same variable per fakeConn, write to either
fakeConn or rowsCursor.

Fixes #20646

Change-Id: Ifc79f989bd1606b8e3ebecb1e7844cce3ad06e17
Reviewed-on: https://go-review.googlesource.com/45393
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/database/sql/fakedb_test.go
src/database/sql/sql_test.go

index 6c8f81ac2acf4f4f9d368f7e93b9816902631c3d..4dcd096ca4d20db0e4fa4d8dc0f3a37ff32f3a50 100644 (file)
@@ -84,6 +84,11 @@ type row struct {
        cols []interface{} // must be same size as its table colname + coltype
 }
 
+type memToucher interface {
+       // touchMem reads & writes some memory, to help find data races.
+       touchMem()
+}
+
 type fakeConn struct {
        db *fakeDB // where to return ourselves to
 
@@ -104,6 +109,10 @@ type fakeConn struct {
        stickyBad bool
 }
 
+func (c *fakeConn) touchMem() {
+       c.line++
+}
+
 func (c *fakeConn) incrStat(v *int) {
        c.mu.Lock()
        *v++
@@ -121,6 +130,7 @@ type boundCol struct {
 }
 
 type fakeStmt struct {
+       memToucher
        c *fakeConn
        q string // just for debugging
 
@@ -303,7 +313,7 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
        if c.currTx != nil {
                return nil, errors.New("already in a transaction")
        }
-       c.line++
+       c.touchMem()
        c.currTx = &fakeTx{c: c}
        return c.currTx, nil
 }
@@ -345,7 +355,7 @@ func (c *fakeConn) Close() (err error) {
                        drv.mu.Unlock()
                }
        }()
-       c.line++
+       c.touchMem()
        if c.currTx != nil {
                return errors.New("can't close fakeConn; in a Transaction")
        }
@@ -533,14 +543,14 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
                return nil, driver.ErrBadConn
        }
 
-       c.line++
+       c.touchMem()
        var firstStmt, prev *fakeStmt
        for _, query := range strings.Split(query, ";") {
                parts := strings.Split(query, "|")
                if len(parts) < 1 {
                        return nil, errf("empty query")
                }
-               stmt := &fakeStmt{q: query, c: c}
+               stmt := &fakeStmt{q: query, c: c, memToucher: c}
                if firstStmt == nil {
                        firstStmt = stmt
                }
@@ -622,7 +632,7 @@ func (s *fakeStmt) Close() error {
        if s.c.db == nil {
                panic("in fakeStmt.Close, conn's db is nil (already closed)")
        }
-       s.c.line++
+       s.touchMem()
        if !s.closed {
                s.c.incrStat(&s.c.stmtsClosed)
                s.closed = true
@@ -657,7 +667,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
        if err != nil {
                return nil, err
        }
-       s.c.line++
+       s.touchMem()
 
        if s.wait > 0 {
                time.Sleep(s.wait)
@@ -770,7 +780,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
                return nil, err
        }
 
-       s.c.line++
+       s.touchMem()
        db := s.c.db
        if len(args) != s.placeholders {
                panic("error in pkg db; should only get here if size is correct")
@@ -866,12 +876,12 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
        }
 
        cursor := &rowsCursor{
-               c:       s.c,
-               posRow:  -1,
-               rows:    setMRows,
-               cols:    setColumns,
-               colType: setColType,
-               errPos:  -1,
+               parentMem: s.c,
+               posRow:    -1,
+               rows:      setMRows,
+               cols:      setColumns,
+               colType:   setColType,
+               errPos:    -1,
        }
        return cursor, nil
 }
@@ -891,7 +901,7 @@ func (tx *fakeTx) Commit() error {
        if hookCommitBadConn != nil && hookCommitBadConn() {
                return driver.ErrBadConn
        }
-       tx.c.line++
+       tx.c.touchMem()
        return nil
 }
 
@@ -903,18 +913,18 @@ func (tx *fakeTx) Rollback() error {
        if hookRollbackBadConn != nil && hookRollbackBadConn() {
                return driver.ErrBadConn
        }
-       tx.c.line++
+       tx.c.touchMem()
        return nil
 }
 
 type rowsCursor struct {
-       c       *fakeConn
-       cols    [][]string
-       colType [][]string
-       posSet  int
-       posRow  int
-       rows    [][]*row
-       closed  bool
+       parentMem memToucher
+       cols      [][]string
+       colType   [][]string
+       posSet    int
+       posRow    int
+       rows      [][]*row
+       closed    bool
 
        // errPos and err are for making Next return early with error.
        errPos int
@@ -924,6 +934,16 @@ type rowsCursor struct {
        // the original slice's first byte address.  we clone them
        // just so we're able to corrupt them on close.
        bytesClone map[*byte][]byte
+
+       // Every operation writes to line to enable the race detector
+       // check for data races.
+       // This is separate from the fakeConn.line to allow for drivers that
+       // can start multiple queries on the same transaction at the same time.
+       line int64
+}
+
+func (rc *rowsCursor) touchMem() {
+       rc.line++
 }
 
 func (rc *rowsCursor) Close() error {
@@ -932,7 +952,8 @@ func (rc *rowsCursor) Close() error {
                        bs[0] = 255 // first byte corrupted
                }
        }
-       rc.c.line++
+       rc.touchMem()
+       rc.parentMem.touchMem()
        rc.closed = true
        return nil
 }
@@ -955,7 +976,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
        if rc.closed {
                return errors.New("fakedb: cursor is closed")
        }
-       rc.c.line++
+       rc.touchMem()
        rc.posRow++
        if rc.posRow == rc.errPos {
                return rc.err
@@ -989,12 +1010,12 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
 }
 
 func (rc *rowsCursor) HasNextResultSet() bool {
-       rc.c.line++
+       rc.touchMem()
        return rc.posSet < len(rc.rows)-1
 }
 
 func (rc *rowsCursor) NextResultSet() error {
-       rc.c.line++
+       rc.touchMem()
        if rc.HasNextResultSet() {
                rc.posSet++
                rc.posRow = -1
index 9fb17df77e5026f74c7cbf5096b2a85540310538..7895aa04047dd5d1c150e0b54d5491607dcd6c92 100644 (file)
@@ -2995,9 +2995,8 @@ func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
                new(concurrentStmtExecTest),
                new(concurrentTxQueryTest),
                new(concurrentTxExecTest),
-               // golang.org/issue/20646
-               // new(concurrentTxStmtQueryTest),
-               // new(concurrentTxStmtExecTest),
+               new(concurrentTxStmtQueryTest),
+               new(concurrentTxStmtExecTest),
        }
        for _, ct := range c.tests {
                ct.init(t, db)
@@ -3243,9 +3242,8 @@ func TestConcurrency(t *testing.T) {
                {"StmtExec", new(concurrentStmtExecTest)},
                {"TxQuery", new(concurrentTxQueryTest)},
                {"TxExec", new(concurrentTxExecTest)},
-               // golang.org/issue/20646
-               // {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
-               // {"TxStmtExec", new(concurrentTxStmtExecTest)},
+               {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
+               {"TxStmtExec", new(concurrentTxStmtExecTest)},
                {"Random", new(concurrentRandomTest)},
        }
        for _, item := range list {
@@ -3582,7 +3580,6 @@ func BenchmarkConcurrentTxExec(b *testing.B) {
 }
 
 func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
-       b.Skip("golang.org/issue/20646")
        b.ReportAllocs()
        ct := new(concurrentTxStmtQueryTest)
        for i := 0; i < b.N; i++ {
@@ -3591,7 +3588,6 @@ func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
 }
 
 func BenchmarkConcurrentTxStmtExec(b *testing.B) {
-       b.Skip("golang.org/issue/20646")
        b.ReportAllocs()
        ct := new(concurrentTxStmtExecTest)
        for i := 0; i < b.N; i++ {