]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: don't hang if the driver Exec method panics
authorIan Lance Taylor <iant@golang.org>
Tue, 31 May 2016 13:58:33 +0000 (06:58 -0700)
committerIan Lance Taylor <iant@golang.org>
Mon, 29 Aug 2016 16:51:56 +0000 (16:51 +0000)
Fixes #13677.
Fixes #15901.

Change-Id: Idffb82cdcba4985954d061bdb021217f47ff4984
Reviewed-on: https://go-review.googlesource.com/23576
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/database/sql/sql.go
src/database/sql/sql_test.go

index 09de1c34e826bdb7468dcd26c5d0393e3864f299..9d8afb01b0f92ada36a6b598fbed926df6c24f19 100644 (file)
@@ -983,9 +983,10 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
        if err != nil {
                return nil, err
        }
-       dc.Lock()
-       si, err := dc.prepareLocked(query)
-       dc.Unlock()
+       var si driver.Stmt
+       withLock(dc, func() {
+               si, err = dc.prepareLocked(query)
+       })
        if err != nil {
                db.putConn(dc, err)
                return nil, err
@@ -1028,13 +1029,15 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
        }()
 
        if execer, ok := dc.ci.(driver.Execer); ok {
-               dargs, err := driverArgs(nil, args)
+               var dargs []driver.Value
+               dargs, err = driverArgs(nil, args)
                if err != nil {
                        return nil, err
                }
-               dc.Lock()
-               resi, err := execer.Exec(query, dargs)
-               dc.Unlock()
+               var resi driver.Result
+               withLock(dc, func() {
+                       resi, err = execer.Exec(query, dargs)
+               })
                if err != driver.ErrSkip {
                        if err != nil {
                                return nil, err
@@ -1043,9 +1046,10 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
                }
        }
 
-       dc.Lock()
-       si, err := dc.ci.Prepare(query)
-       dc.Unlock()
+       var si driver.Stmt
+       withLock(dc, func() {
+               si, err = dc.ci.Prepare(query)
+       })
        if err != nil {
                return nil, err
        }
@@ -1088,9 +1092,10 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
                        releaseConn(err)
                        return nil, err
                }
-               dc.Lock()
-               rowsi, err := queryer.Query(query, dargs)
-               dc.Unlock()
+               var rowsi driver.Rows
+               withLock(dc, func() {
+                       rowsi, err = queryer.Query(query, dargs)
+               })
                if err != driver.ErrSkip {
                        if err != nil {
                                releaseConn(err)
@@ -1107,9 +1112,11 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
                }
        }
 
-       dc.Lock()
-       si, err := dc.ci.Prepare(query)
-       dc.Unlock()
+       var si driver.Stmt
+       var err error
+       withLock(dc, func() {
+               si, err = dc.ci.Prepare(query)
+       })
        if err != nil {
                releaseConn(err)
                return nil, err
@@ -1118,9 +1125,9 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
        ds := driverStmt{dc, si}
        rowsi, err := rowsiFromStatement(ds, args...)
        if err != nil {
-               dc.Lock()
-               si.Close()
-               dc.Unlock()
+               withLock(dc, func() {
+                       si.Close()
+               })
                releaseConn(err)
                return nil, err
        }
@@ -1166,9 +1173,10 @@ func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
        if err != nil {
                return nil, err
        }
-       dc.Lock()
-       txi, err := dc.ci.Begin()
-       dc.Unlock()
+       var txi driver.Tx
+       withLock(dc, func() {
+               txi, err = dc.ci.Begin()
+       })
        if err != nil {
                db.putConn(dc, err)
                return nil, err
@@ -1238,10 +1246,10 @@ func (tx *Tx) grabConn() (*driverConn, error) {
 // Closes all Stmts prepared for this transaction.
 func (tx *Tx) closePrepared() {
        tx.stmts.Lock()
+       defer tx.stmts.Unlock()
        for _, stmt := range tx.stmts.v {
                stmt.Close()
        }
-       tx.stmts.Unlock()
 }
 
 // Commit commits the transaction.
@@ -1249,9 +1257,10 @@ func (tx *Tx) Commit() error {
        if tx.done {
                return ErrTxDone
        }
-       tx.dc.Lock()
-       err := tx.txi.Commit()
-       tx.dc.Unlock()
+       var err error
+       withLock(tx.dc, func() {
+               err = tx.txi.Commit()
+       })
        if err != driver.ErrBadConn {
                tx.closePrepared()
        }
@@ -1264,9 +1273,10 @@ func (tx *Tx) Rollback() error {
        if tx.done {
                return ErrTxDone
        }
-       tx.dc.Lock()
-       err := tx.txi.Rollback()
-       tx.dc.Unlock()
+       var err error
+       withLock(tx.dc, func() {
+               err = tx.txi.Rollback()
+       })
        if err != driver.ErrBadConn {
                tx.closePrepared()
        }
@@ -1299,9 +1309,10 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
                return nil, err
        }
 
-       dc.Lock()
-       si, err := dc.ci.Prepare(query)
-       dc.Unlock()
+       var si driver.Stmt
+       withLock(dc, func() {
+               si, err = dc.ci.Prepare(query)
+       })
        if err != nil {
                return nil, err
        }
@@ -1346,9 +1357,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
        if err != nil {
                return &Stmt{stickyErr: err}
        }
-       dc.Lock()
-       si, err := dc.ci.Prepare(stmt.query)
-       dc.Unlock()
+       var si driver.Stmt
+       withLock(dc, func() {
+               si, err = dc.ci.Prepare(stmt.query)
+       })
        txs := &Stmt{
                db: tx.db,
                tx: tx,
@@ -1378,9 +1390,10 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
                if err != nil {
                        return nil, err
                }
-               dc.Lock()
-               resi, err := execer.Exec(query, dargs)
-               dc.Unlock()
+               var resi driver.Result
+               withLock(dc, func() {
+                       resi, err = execer.Exec(query, dargs)
+               })
                if err == nil {
                        return driverResult{dc, resi}, nil
                }
@@ -1389,9 +1402,10 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
                }
        }
 
-       dc.Lock()
-       si, err := dc.ci.Prepare(query)
-       dc.Unlock()
+       var si driver.Stmt
+       withLock(dc, func() {
+               si, err = dc.ci.Prepare(query)
+       })
        if err != nil {
                return nil, err
        }
@@ -1578,9 +1592,9 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
        s.mu.Unlock()
 
        // No luck; we need to prepare the statement on this connection
-       dc.Lock()
-       si, err = dc.prepareLocked(s.query)
-       dc.Unlock()
+       withLock(dc, func() {
+               si, err = dc.prepareLocked(s.query)
+       })
        if err != nil {
                s.db.putConn(dc, err)
                return nil, nil, nil, err
@@ -1635,9 +1649,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
 }
 
 func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
-       ds.Lock()
-       want := ds.si.NumInput()
-       ds.Unlock()
+       var want int
+       withLock(ds, func() {
+               want = ds.si.NumInput()
+       })
 
        // -1 means the driver doesn't know how to count the number of
        // placeholders, so we won't sanity check input here and instead let the
@@ -1652,8 +1667,8 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
        }
 
        ds.Lock()
+       defer ds.Unlock()
        rowsi, err := ds.si.Query(dargs)
-       ds.Unlock()
        if err != nil {
                return nil, err
        }
@@ -1695,9 +1710,8 @@ func (s *Stmt) Close() error {
        s.closed = true
 
        if s.tx != nil {
-               err := s.txsi.Close()
-               s.mu.Unlock()
-               return err
+               defer s.mu.Unlock()
+               return s.txsi.Close()
        }
        s.mu.Unlock()
 
index 08df0c7666a357ae173a5d3b5918d3bcb1c9862a..41afd00e92e07d9d9114f22c46bb9a8d6fbbfe64 100644 (file)
@@ -2299,6 +2299,53 @@ func TestConnectionLeak(t *testing.T) {
        wg.Wait()
 }
 
+// badConn implements a bad driver.Conn, for TestBadDriver.
+// The Exec method panics.
+type badConn struct{}
+
+func (bc badConn) Prepare(query string) (driver.Stmt, error) {
+       return nil, errors.New("badConn Prepare")
+}
+
+func (bc badConn) Close() error {
+       return nil
+}
+
+func (bc badConn) Begin() (driver.Tx, error) {
+       return nil, errors.New("badConn Begin")
+}
+
+func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+       panic("badConn.Exec")
+}
+
+// badDriver is a driver.Driver that uses badConn.
+type badDriver struct{}
+
+func (bd badDriver) Open(name string) (driver.Conn, error) {
+       return badConn{}, nil
+}
+
+// Issue 15901.
+func TestBadDriver(t *testing.T) {
+       Register("bad", badDriver{})
+       db, err := Open("bad", "ignored")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer func() {
+               if r := recover(); r == nil {
+                       t.Error("expected panic")
+               } else {
+                       if want := "badConn.Exec"; r.(string) != want {
+                               t.Errorf("panic was %v, expected %v", r, want)
+                       }
+               }
+       }()
+       defer db.Close()
+       db.Exec("ignored")
+}
+
 func BenchmarkConcurrentDBExec(b *testing.B) {
        b.ReportAllocs()
        ct := new(concurrentDBExecTest)