]> Cypherpunks repositories - gostls13.git/commitdiff
sql: add DB.Close, fix bugs, remove Execer on Driver (only Conn)
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 14 Nov 2011 18:48:26 +0000 (10:48 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 14 Nov 2011 18:48:26 +0000 (10:48 -0800)
R=rsc
CC=golang-dev
https://golang.org/cl/5372099

src/pkg/exp/sql/convert.go
src/pkg/exp/sql/driver/driver.go
src/pkg/exp/sql/fakedb_test.go
src/pkg/exp/sql/sql.go
src/pkg/exp/sql/sql_test.go

index e46cebe9a3da51f5e8901b41dfad61222209e52e..48e281203bead02bd2f197a330c9a0528bd81090 100644 (file)
@@ -14,6 +14,21 @@ import (
        "strconv"
 )
 
+// subsetTypeArgs takes a slice of arguments from callers of the sql
+// package and converts them into a slice of the driver package's
+// "subset types".
+func subsetTypeArgs(args []interface{}) ([]interface{}, error) {
+       out := make([]interface{}, len(args))
+       for n, arg := range args {
+               var err error
+               out[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
+               if err != nil {
+                       return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err)
+               }
+       }
+       return out, nil
+}
+
 // convertAssign copies to dest the value in src, converting it if possible.
 // An error is returned if the copy would result in loss of information.
 // dest should be a pointer type.
index 6a51c342415a33bec30bf2e4ba4ee7e5e7a51432..35fc6ae43c1dd0cef837a40323ae757936efc698 100644 (file)
@@ -36,19 +36,22 @@ type Driver interface {
        Open(name string) (Conn, error)
 }
 
-// Execer is an optional interface that may be implemented by a Driver
-// or a Conn.
-//
-// If a Driver does not implement Execer, the sql package's DB.Exec
-// method first obtains a free connection from its free pool or from
-// the driver's Open method. Execer should only be implemented by
-// drivers that can provide a more efficient implementation.
+// ErrSkip may be returned by some optional interfaces' methods to
+// indicate at runtime that the fast path is unavailable and the sql
+// package should continue as if the optional interface was not
+// implemented. ErrSkip is only supported where explicitly
+// documented.
+var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
+
+// Execer is an optional interface that may be implemented by a Conn.
 //
 // If a Conn does not implement Execer, the db package's DB.Exec will
 // first prepare a query, execute the statement, and then close the
 // statement.
 //
 // All arguments are of a subset type as defined in the package docs.
+//
+// Exec may return ErrSkip.
 type Execer interface {
        Exec(query string, args []interface{}) (Result, error)
 }
index c8a19974d641df1317c832829ba7929a7a71df9c..17028e2cc388401fec45cecff753b79a7f15403d 100644 (file)
@@ -195,6 +195,29 @@ func (c *fakeConn) Close() error {
        return nil
 }
 
+func checkSubsetTypes(args []interface{}) error {
+       for n, arg := range args {
+               switch arg.(type) {
+               case int64, float64, bool, nil, []byte, string:
+               default:
+                       return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
+               }
+       }
+       return nil
+}
+
+func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
+       // This is an optional interface, but it's implemented here
+       // just to check that all the args of of the proper types.
+       // ErrSkip is returned so the caller acts as if we didn't
+       // implement this at all.
+       err := checkSubsetTypes(args)
+       if err != nil {
+               return nil, err
+       }
+       return nil, driver.ErrSkip
+}
+
 func errf(msg string, args ...interface{}) error {
        return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
 }
@@ -323,6 +346,11 @@ func (s *fakeStmt) Close() error {
 }
 
 func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
+       err := checkSubsetTypes(args)
+       if err != nil {
+               return nil, err
+       }
+
        db := s.c.db
        switch s.cmd {
        case "WIPE":
@@ -377,6 +405,11 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
 }
 
 func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
+       err := checkSubsetTypes(args)
+       if err != nil {
+               return nil, err
+       }
+
        db := s.c.db
        if len(args) != s.placeholders {
                panic("error in pkg db; should only get here if size is correct")
index 291af7f67dcb0878f8c1fce8322daa9f0ea2619d..d3677afb3bae144ba4d20ea86eb91fae7885a993 100644 (file)
@@ -88,8 +88,9 @@ type DB struct {
        driver driver.Driver
        dsn    string
 
-       mu       sync.Mutex
+       mu       sync.Mutex // protects freeConn and closed
        freeConn []driver.Conn
+       closed   bool
 }
 
 // Open opens a database specified by its database driver name and a
@@ -106,6 +107,22 @@ func Open(driverName, dataSourceName string) (*DB, error) {
        return &DB{driver: driver, dsn: dataSourceName}, nil
 }
 
+// Close closes the database, releasing any open resources.
+func (db *DB) Close() error {
+       db.mu.Lock()
+       defer db.mu.Unlock()
+       var err error
+       for _, c := range db.freeConn {
+               err1 := c.Close()
+               if err1 != nil {
+                       err = err1
+               }
+       }
+       db.freeConn = nil
+       db.closed = true
+       return err
+}
+
 func (db *DB) maxIdleConns() int {
        const defaultMaxIdleConns = 2
        // TODO(bradfitz): ask driver, if supported, for its default preference
@@ -116,6 +133,9 @@ func (db *DB) maxIdleConns() int {
 // conn returns a newly-opened or cached driver.Conn
 func (db *DB) conn() (driver.Conn, error) {
        db.mu.Lock()
+       if db.closed {
+               return nil, errors.New("sql: database is closed")
+       }
        if n := len(db.freeConn); n > 0 {
                conn := db.freeConn[n-1]
                db.freeConn = db.freeConn[:n-1]
@@ -140,11 +160,13 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
 }
 
 func (db *DB) putConn(c driver.Conn) {
-       if n := len(db.freeConn); n < db.maxIdleConns() {
+       db.mu.Lock()
+       defer db.mu.Unlock()
+       if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
                db.freeConn = append(db.freeConn, c)
                return
        }
-       db.closeConn(c)
+       db.closeConn(c) // TODO(bradfitz): release lock before calling this?
 }
 
 func (db *DB) closeConn(c driver.Conn) {
@@ -180,17 +202,11 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
 
 // Exec executes a query without returning any rows.
 func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
-       // Optional fast path, if the driver implements driver.Execer.
-       if execer, ok := db.driver.(driver.Execer); ok {
-               resi, err := execer.Exec(query, args)
-               if err != nil {
-                       return nil, err
-               }
-               return result{resi}, nil
+       sargs, err := subsetTypeArgs(args)
+       if err != nil {
+               return nil, err
        }
 
-       // If the driver does not implement driver.Execer, we need
-       // a connection.
        ci, err := db.conn()
        if err != nil {
                return nil, err
@@ -198,11 +214,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
        defer db.putConn(ci)
 
        if execer, ok := ci.(driver.Execer); ok {
-               resi, err := execer.Exec(query, args)
-               if err != nil {
-                       return nil, err
+               resi, err := execer.Exec(query, sargs)
+               if err != driver.ErrSkip {
+                       if err != nil {
+                               return nil, err
+                       }
+                       return result{resi}, nil
                }
-               return result{resi}, nil
        }
 
        sti, err := ci.Prepare(query)
@@ -210,7 +228,8 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
                return nil, err
        }
        defer sti.Close()
-       resi, err := sti.Exec(args)
+
+       resi, err := sti.Exec(sargs)
        if err != nil {
                return nil, err
        }
@@ -386,7 +405,13 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
                return nil, err
        }
        defer sti.Close()
-       resi, err := sti.Exec(args)
+
+       sargs, err := subsetTypeArgs(args)
+       if err != nil {
+               return nil, err
+       }
+
+       resi, err := sti.Exec(sargs)
        if err != nil {
                return nil, err
        }
@@ -548,7 +573,11 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
        if len(args) != si.NumInput() {
                return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args))
        }
-       rowsi, err := si.Query(args)
+       sargs, err := subsetTypeArgs(args)
+       if err != nil {
+               return nil, err
+       }
+       rowsi, err := si.Query(sargs)
        if err != nil {
                s.db.putConn(ci)
                return nil, err
index eb1bb58966eae95ee8f8011dcf1d4669433f5681..d365f6ba190970d0f0124a8ea0208f787a6c0844 100644 (file)
@@ -34,8 +34,16 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) {
        }
 }
 
+func closeDB(t *testing.T, db *DB) {
+       err := db.Close()
+       if err != nil {
+               t.Fatalf("error closing DB: %v", err)
+       }
+}
+
 func TestQuery(t *testing.T) {
        db := newTestDB(t, "people")
+       defer closeDB(t, db)
        var name string
        var age int
 
@@ -69,6 +77,7 @@ func TestQuery(t *testing.T) {
 
 func TestStatementQueryRow(t *testing.T) {
        db := newTestDB(t, "people")
+       defer closeDB(t, db)
        stmt, err := db.Prepare("SELECT|people|age|name=?")
        if err != nil {
                t.Fatalf("Prepare: %v", err)
@@ -94,6 +103,7 @@ func TestStatementQueryRow(t *testing.T) {
 // just a test of fakedb itself
 func TestBogusPreboundParameters(t *testing.T) {
        db := newTestDB(t, "foo")
+       defer closeDB(t, db)
        exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
        _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
        if err == nil {
@@ -106,6 +116,7 @@ func TestBogusPreboundParameters(t *testing.T) {
 
 func TestDb(t *testing.T) {
        db := newTestDB(t, "foo")
+       defer closeDB(t, db)
        exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
        stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
        if err != nil {