]> Cypherpunks repositories - gostls13.git/commitdiff
exp/sql: close Rows on EOF
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 10 Jan 2012 20:51:27 +0000 (12:51 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 10 Jan 2012 20:51:27 +0000 (12:51 -0800)
Fixes #2624

R=rsc
CC=golang-dev
https://golang.org/cl/5530068

src/pkg/exp/sql/fakedb_test.go
src/pkg/exp/sql/sql.go
src/pkg/exp/sql/sql_test.go

index 2474a86f644c46fc5d5988fad918fd7bcf82c9b3..0883dd9f3edaa5707a52e169a180f8489503d1d7 100644 (file)
@@ -110,25 +110,34 @@ func init() {
 
 // Supports dsn forms:
 //    <dbname>
-//    <dbname>;wipe
+//    <dbname>;<opts>  (no currently supported options)
 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
-       d.mu.Lock()
-       defer d.mu.Unlock()
-       d.openCount++
-       if d.dbs == nil {
-               d.dbs = make(map[string]*fakeDB)
-       }
        parts := strings.Split(dsn, ";")
        if len(parts) < 1 {
                return nil, errors.New("fakedb: no database name")
        }
        name := parts[0]
+
+       db := d.getDB(name)
+
+       d.mu.Lock()
+       d.openCount++
+       d.mu.Unlock()
+       return &fakeConn{db: db}, nil
+}
+
+func (d *fakeDriver) getDB(name string) *fakeDB {
+       d.mu.Lock()
+       defer d.mu.Unlock()
+       if d.dbs == nil {
+               d.dbs = make(map[string]*fakeDB)
+       }
        db, ok := d.dbs[name]
        if !ok {
                db = &fakeDB{name: name}
                d.dbs[name] = db
        }
-       return &fakeConn{db: db}, nil
+       return db
 }
 
 func (db *fakeDB) wipe() {
index 937982cdbe6f9b3661ea963047f0efe9d92fd828..f53691b7c42dad179d5c1dbf333e7e40cb6c7a68 100644 (file)
@@ -549,8 +549,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
 func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
-       if s.stickyErr != nil {
-               return nil, nil, nil, s.stickyErr
+       if err = s.stickyErr; err != nil {
+               return
        }
        s.mu.Lock()
        if s.closed {
@@ -726,6 +726,9 @@ func (rs *Rows) Next() bool {
                rs.lastcols = make([]interface{}, len(rs.rowsi.Columns()))
        }
        rs.lasterr = rs.rowsi.Next(rs.lastcols)
+       if rs.lasterr == io.EOF {
+               rs.Close()
+       }
        return rs.lasterr == nil
 }
 
index 5307a235ddf13fc0eb84b09eb7da9e3358ed0f3a..590bf818fef3ee614de797d14ab6bdfe3eab43b2 100644 (file)
@@ -10,8 +10,10 @@ import (
        "testing"
 )
 
+const fakeDBName = "foo"
+
 func newTestDB(t *testing.T, name string) *DB {
-       db, err := Open("test", "foo")
+       db, err := Open("test", fakeDBName)
        if err != nil {
                t.Fatalf("Open: %v", err)
        }
@@ -73,6 +75,12 @@ func TestQuery(t *testing.T) {
        if !reflect.DeepEqual(got, want) {
                t.Logf(" got: %#v\nwant: %#v", got, want)
        }
+
+       // And verify that the final rows.Next() call, which hit EOF,
+       // also closed the rows connection.
+       if n := len(db.freeConn); n != 1 {
+               t.Errorf("free conns after query hitting EOF = %d; want 1", n)
+       }
 }
 
 func TestRowsColumns(t *testing.T) {