From: Gwenael Treguier Date: Sat, 10 Mar 2012 23:21:44 +0000 (-0800) Subject: database/sql: ensure Stmts are correctly closed. X-Git-Tag: weekly.2012-03-13~55 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=c3954dd5da820306dff6bbd73a7801d9739b038f;p=gostls13.git database/sql: ensure Stmts are correctly closed. To make sure that there is no resource leak, I suggest to fix the 'fakedb' driver such as it fails when any Stmt is not closed. First, add a check in fakeConn.Close(). Then, fix all missing Stmt.Close()/Rows.Close(). I am not sure that the strategy choose in fakeConn.Prepare/prepare* is ok. The weak point in this patch is the change in Tx.Query: - Tests pass without this change, - I found it by manually analyzing the code, - I just try to make Tx.Query look like DB.Query. R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/5759050 --- diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go index 8732d028bc..184e7756c5 100644 --- a/src/pkg/database/sql/fakedb_test.go +++ b/src/pkg/database/sql/fakedb_test.go @@ -214,6 +214,9 @@ func (c *fakeConn) Close() error { if c.db == nil { return errors.New("can't close fakeConn; already closed") } + if c.stmtsMade > c.stmtsClosed { + return errors.New("can't close; dangling statement(s)") + } c.db = nil return nil } @@ -250,6 +253,7 @@ func errf(msg string, args ...interface{}) error { // just a limitation for fakedb) func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { if len(parts) != 3 { + stmt.Close() return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) } stmt.table = parts[0] @@ -260,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e } nameVal := strings.Split(colspec, "=") if len(nameVal) != 2 { + stmt.Close() return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) } column, value := nameVal[0], nameVal[1] _, ok := c.db.columnType(stmt.table, column) if !ok { + stmt.Close() return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) } if value != "?" { + stmt.Close() return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", stmt.table, column) } @@ -280,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e // parts are table|col=type,col2=type2 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { if len(parts) != 2 { + stmt.Close() return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) } stmt.table = parts[0] for n, colspec := range strings.Split(parts[1], ",") { nameType := strings.Split(colspec, "=") if len(nameType) != 2 { + stmt.Close() return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) } stmt.colName = append(stmt.colName, nameType[0]) @@ -297,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e // parts are table|col=?,col2=val func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { if len(parts) != 2 { + stmt.Close() return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) } stmt.table = parts[0] for n, colspec := range strings.Split(parts[1], ",") { nameVal := strings.Split(colspec, "=") if len(nameVal) != 2 { + stmt.Close() return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) } column, value := nameVal[0], nameVal[1] ctype, ok := c.db.columnType(stmt.table, column) if !ok { + stmt.Close() return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) } stmt.colName = append(stmt.colName, column) @@ -323,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e case "int32": i, err := strconv.Atoi(value) if err != nil { + stmt.Close() return nil, errf("invalid conversion to int32 from %q", value) } subsetVal = int64(i) // int64 is a subset type, but not int32 default: + stmt.Close() return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) } stmt.colValue = append(stmt.colValue, subsetVal) @@ -362,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { case "INSERT": return c.prepareInsert(stmt, parts) default: + stmt.Close() return nil, errf("unsupported command type %q", cmd) } return stmt, nil diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go index c00425d8fa..51a357b37d 100644 --- a/src/pkg/database/sql/sql.go +++ b/src/pkg/database/sql/sql.go @@ -612,9 +612,11 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { return nil, err } rows, err := stmt.Query(args...) - if err == nil { - rows.closeStmt = stmt + if err != nil { + stmt.Close() + return nil, err } + rows.closeStmt = stmt return rows, err } @@ -1020,7 +1022,7 @@ func (r *Row) Scan(dest ...interface{}) error { } // TODO(bradfitz): for now we need to defensively clone all - // []byte that the driver returned (not permitting + // []byte that the driver returned (not permitting // *RawBytes in Rows.Scan), since we're about to close // the Rows in our defer, when we return from this function. // the contract with the driver.Next(...) interface is that it diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go index 90a40efa28..b296705865 100644 --- a/src/pkg/database/sql/sql_test.go +++ b/src/pkg/database/sql/sql_test.go @@ -251,6 +251,7 @@ func TestStatementQueryRow(t *testing.T) { if err != nil { t.Fatalf("Prepare: %v", err) } + defer stmt.Close() var age int for n, tt := range []struct { name string @@ -291,6 +292,7 @@ func TestExec(t *testing.T) { if err != nil { t.Errorf("Stmt, err = %v, %v", stmt, err) } + defer stmt.Close() type execTest struct { args []interface{} @@ -332,11 +334,14 @@ func TestTxStmt(t *testing.T) { if err != nil { t.Fatalf("Stmt, err = %v, %v", stmt, err) } + defer stmt.Close() tx, err := db.Begin() if err != nil { t.Fatalf("Begin = %v", err) } - _, err = tx.Stmt(stmt).Exec("Bobby", 7) + txs := tx.Stmt(stmt) + defer txs.Close() + _, err = txs.Exec("Bobby", 7) if err != nil { t.Fatalf("Exec = %v", err) } @@ -365,6 +370,7 @@ func TestTxQuery(t *testing.T) { if err != nil { t.Fatal(err) } + defer r.Close() if !r.Next() { if r.Err() != nil { @@ -561,6 +567,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) { if err != nil { t.Fatalf("prepare: %v", err) } + defer stmt.Close() if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil { t.Errorf("exec insert chris: %v", err) }