From: Russ Cox Date: Fri, 18 Dec 2015 17:16:05 +0000 (-0500) Subject: database/sql: guard against panics in driver.Stmt implementation X-Git-Tag: go1.6beta2~71 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=99ed71a02c07a073a82a06c8f3a975a9540e962e;p=gostls13.git database/sql: guard against panics in driver.Stmt implementation For #13677, but there is more to do. Change-Id: Id1af999dc972d07cdfc771e5855a1a7dca47ca96 Reviewed-on: https://go-review.googlesource.com/18046 Reviewed-by: Brad Fitzpatrick --- diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index f1e8f6cb6e..b5ff121358 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -33,6 +33,9 @@ var _ = log.Printf // INSERT||col=val,col2=val2,col3=? // SELECT||projectcol1,projectcol2|filtercol=?,filtercol2=? // +// Any of these can be preceded by PANIC||, to cause the +// named method on fakeStmt to panic. +// // When opening a fakeDriver's database, it starts empty with no // tables. All tables and data are stored in memory only. type fakeDriver struct { @@ -111,6 +114,7 @@ type fakeStmt struct { cmd string table string + panic string closed bool @@ -499,9 +503,15 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { if len(parts) < 1 { return nil, errf("empty query") } + stmt := &fakeStmt{q: query, c: c} + if len(parts) >= 3 && parts[0] == "PANIC" { + stmt.panic = parts[1] + parts = parts[2:] + } cmd := parts[0] + stmt.cmd = cmd parts = parts[1:] - stmt := &fakeStmt{q: query, c: c, cmd: cmd} + c.incrStat(&c.stmtsMade) switch cmd { case "WIPE": @@ -524,6 +534,9 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { } func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { + if s.panic == "ColumnConverter" { + panic(s.panic) + } if len(s.placeholderConverter) == 0 { return driver.DefaultParameterConverter } @@ -531,6 +544,9 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { } func (s *fakeStmt) Close() error { + if s.panic == "Close" { + panic(s.panic) + } if s.c == nil { panic("nil conn in fakeStmt.Close") } @@ -550,6 +566,9 @@ var errClosed = errors.New("fakedb: statement has been closed") var hookExecBadConn func() bool func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { + if s.panic == "Exec" { + panic(s.panic) + } if s.closed { return nil, errClosed } @@ -634,6 +653,9 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result var hookQueryBadConn func() bool func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { + if s.panic == "Query" { + panic(s.panic) + } if s.closed { return nil, errClosed } @@ -716,6 +738,9 @@ rows: } func (s *fakeStmt) NumInput() int { + if s.panic == "NumInput" { + panic(s.panic) + } return s.placeholders } diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 2d1528a21f..b5a17f0fc1 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1477,10 +1477,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return nil, driver.ErrBadConn } -func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { +func driverNumInput(ds driverStmt) int { ds.Lock() - want := ds.si.NumInput() - ds.Unlock() + defer ds.Unlock() // in case NumInput panics + return ds.si.NumInput() +} + +func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { + want := driverNumInput(ds) // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the @@ -1495,8 +1499,8 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { } ds.Lock() + defer ds.Unlock() resi, err := ds.si.Exec(dargs) - ds.Unlock() if err != nil { return nil, err } @@ -1927,6 +1931,6 @@ func stack() string { // withLock runs while holding lk. func withLock(lk sync.Locker, fn func()) { lk.Lock() + defer lk.Unlock() // in case fn panics fn() - lk.Unlock() } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 48c872d8c6..8ec70d99b0 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -68,6 +68,46 @@ func newTestDB(t testing.TB, name string) *DB { return db } +func TestDriverPanic(t *testing.T) { + // Test that if driver panics, database/sql does not deadlock. + db, err := Open("test", fakeDBName) + if err != nil { + t.Fatalf("Open: %v", err) + } + expectPanic := func(name string, f func()) { + defer func() { + err := recover() + if err == nil { + t.Fatalf("%s did not panic", name) + } + }() + f() + } + + expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + exec(t, db, "PANIC|Query|WIPE") // should run successfully: Exec does not call Query + exec(t, db, "WIPE") // check not deadlocked + + exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime") + + expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") }) + expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") }) + expectPanic("Query Close", func() { + rows, err := db.Query("PANIC|Close|SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + rows.Close() + }) + db.Query("PANIC|Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec + exec(t, db, "WIPE") // check not deadlocked +} + func exec(t testing.TB, db *DB, query string, args ...interface{}) { _, err := db.Exec(query, args...) if err != nil {