]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: ensure Stmts are correctly closed.
authorGwenael Treguier <gwenn.kahz@gmail.com>
Sat, 10 Mar 2012 23:21:44 +0000 (15:21 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 10 Mar 2012 23:21:44 +0000 (15:21 -0800)
To make sure that there is no resource leak,
I suggest to fix the 'fakedb' driver such as it fails when any
Stmt is not closed.
First, add a check in fakeConn.Close().
Then, fix all missing Stmt.Close()/Rows.Close().
I am not sure that the strategy choose in fakeConn.Prepare/prepare* is ok.
The weak point in this patch is the change in Tx.Query:
  - Tests pass without this change,
  - I found it by manually analyzing the code,
  - I just try to make Tx.Query look like DB.Query.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/5759050

src/pkg/database/sql/fakedb_test.go
src/pkg/database/sql/sql.go
src/pkg/database/sql/sql_test.go

index 8732d028bc56d36022f01c38ca850fea1721639e..184e7756c51e4228e89e420393519bff05361a81 100644 (file)
@@ -214,6 +214,9 @@ func (c *fakeConn) Close() error {
        if c.db == nil {
                return errors.New("can't close fakeConn; already closed")
        }
+       if c.stmtsMade > c.stmtsClosed {
+               return errors.New("can't close; dangling statement(s)")
+       }
        c.db = nil
        return nil
 }
@@ -250,6 +253,7 @@ func errf(msg string, args ...interface{}) error {
 //  just a limitation for fakedb)
 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 3 {
+               stmt.Close()
                return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
        }
        stmt.table = parts[0]
@@ -260,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
                }
                nameVal := strings.Split(colspec, "=")
                if len(nameVal) != 2 {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                column, value := nameVal[0], nameVal[1]
                _, ok := c.db.columnType(stmt.table, column)
                if !ok {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
                }
                if value != "?" {
+                       stmt.Close()
                        return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
                                stmt.table, column)
                }
@@ -280,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
 // parts are table|col=type,col2=type2
 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 2 {
+               stmt.Close()
                return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
        }
        stmt.table = parts[0]
        for n, colspec := range strings.Split(parts[1], ",") {
                nameType := strings.Split(colspec, "=")
                if len(nameType) != 2 {
+                       stmt.Close()
                        return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                stmt.colName = append(stmt.colName, nameType[0])
@@ -297,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
 // parts are table|col=?,col2=val
 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
        if len(parts) != 2 {
+               stmt.Close()
                return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
        }
        stmt.table = parts[0]
        for n, colspec := range strings.Split(parts[1], ",") {
                nameVal := strings.Split(colspec, "=")
                if len(nameVal) != 2 {
+                       stmt.Close()
                        return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
                }
                column, value := nameVal[0], nameVal[1]
                ctype, ok := c.db.columnType(stmt.table, column)
                if !ok {
+                       stmt.Close()
                        return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
                }
                stmt.colName = append(stmt.colName, column)
@@ -323,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
                        case "int32":
                                i, err := strconv.Atoi(value)
                                if err != nil {
+                                       stmt.Close()
                                        return nil, errf("invalid conversion to int32 from %q", value)
                                }
                                subsetVal = int64(i) // int64 is a subset type, but not int32
                        default:
+                               stmt.Close()
                                return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
                        }
                        stmt.colValue = append(stmt.colValue, subsetVal)
@@ -362,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
        case "INSERT":
                return c.prepareInsert(stmt, parts)
        default:
+               stmt.Close()
                return nil, errf("unsupported command type %q", cmd)
        }
        return stmt, nil
index c00425d8fa1e25e2c353cf0bb106c235bfd7caf7..51a357b37debbc2cd6cefe176a6472686699a5c8 100644 (file)
@@ -612,9 +612,11 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
                return nil, err
        }
        rows, err := stmt.Query(args...)
-       if err == nil {
-               rows.closeStmt = stmt
+       if err != nil {
+               stmt.Close()
+               return nil, err
        }
+       rows.closeStmt = stmt
        return rows, err
 }
 
@@ -1020,7 +1022,7 @@ func (r *Row) Scan(dest ...interface{}) error {
        }
 
        // TODO(bradfitz): for now we need to defensively clone all
-       // []byte that the driver returned (not permitting 
+       // []byte that the driver returned (not permitting
        // *RawBytes in Rows.Scan), since we're about to close
        // the Rows in our defer, when we return from this function.
        // the contract with the driver.Next(...) interface is that it
index 90a40efa2861477f3e0aa6675011316d4372e546..b296705865f10c935efe541ae771a41de0beb8a1 100644 (file)
@@ -251,6 +251,7 @@ func TestStatementQueryRow(t *testing.T) {
        if err != nil {
                t.Fatalf("Prepare: %v", err)
        }
+       defer stmt.Close()
        var age int
        for n, tt := range []struct {
                name string
@@ -291,6 +292,7 @@ func TestExec(t *testing.T) {
        if err != nil {
                t.Errorf("Stmt, err = %v, %v", stmt, err)
        }
+       defer stmt.Close()
 
        type execTest struct {
                args    []interface{}
@@ -332,11 +334,14 @@ func TestTxStmt(t *testing.T) {
        if err != nil {
                t.Fatalf("Stmt, err = %v, %v", stmt, err)
        }
+       defer stmt.Close()
        tx, err := db.Begin()
        if err != nil {
                t.Fatalf("Begin = %v", err)
        }
-       _, err = tx.Stmt(stmt).Exec("Bobby", 7)
+       txs := tx.Stmt(stmt)
+       defer txs.Close()
+       _, err = txs.Exec("Bobby", 7)
        if err != nil {
                t.Fatalf("Exec = %v", err)
        }
@@ -365,6 +370,7 @@ func TestTxQuery(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
+       defer r.Close()
 
        if !r.Next() {
                if r.Err() != nil {
@@ -561,6 +567,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
        if err != nil {
                t.Fatalf("prepare: %v", err)
        }
+       defer stmt.Close()
        if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
                t.Errorf("exec insert chris: %v", err)
        }