From 0a8005c7729951e26a37d17bb42a989f30bb415d Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 14 Nov 2011 10:48:26 -0800 Subject: [PATCH] sql: add DB.Close, fix bugs, remove Execer on Driver (only Conn) R=rsc CC=golang-dev https://golang.org/cl/5372099 --- src/pkg/exp/sql/convert.go | 15 +++++++ src/pkg/exp/sql/driver/driver.go | 17 ++++---- src/pkg/exp/sql/fakedb_test.go | 33 ++++++++++++++++ src/pkg/exp/sql/sql.go | 67 +++++++++++++++++++++++--------- src/pkg/exp/sql/sql_test.go | 11 ++++++ 5 files changed, 117 insertions(+), 26 deletions(-) diff --git a/src/pkg/exp/sql/convert.go b/src/pkg/exp/sql/convert.go index e46cebe9a3..48e281203b 100644 --- a/src/pkg/exp/sql/convert.go +++ b/src/pkg/exp/sql/convert.go @@ -14,6 +14,21 @@ import ( "strconv" ) +// subsetTypeArgs takes a slice of arguments from callers of the sql +// package and converts them into a slice of the driver package's +// "subset types". +func subsetTypeArgs(args []interface{}) ([]interface{}, error) { + out := make([]interface{}, len(args)) + for n, arg := range args { + var err error + out[n], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err) + } + } + return out, nil +} + // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. diff --git a/src/pkg/exp/sql/driver/driver.go b/src/pkg/exp/sql/driver/driver.go index 6a51c34241..35fc6ae43c 100644 --- a/src/pkg/exp/sql/driver/driver.go +++ b/src/pkg/exp/sql/driver/driver.go @@ -36,19 +36,22 @@ type Driver interface { Open(name string) (Conn, error) } -// Execer is an optional interface that may be implemented by a Driver -// or a Conn. -// -// If a Driver does not implement Execer, the sql package's DB.Exec -// method first obtains a free connection from its free pool or from -// the driver's Open method. Execer should only be implemented by -// drivers that can provide a more efficient implementation. +// ErrSkip may be returned by some optional interfaces' methods to +// indicate at runtime that the fast path is unavailable and the sql +// package should continue as if the optional interface was not +// implemented. ErrSkip is only supported where explicitly +// documented. +var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented") + +// Execer is an optional interface that may be implemented by a Conn. // // If a Conn does not implement Execer, the db package's DB.Exec will // first prepare a query, execute the statement, and then close the // statement. // // All arguments are of a subset type as defined in the package docs. +// +// Exec may return ErrSkip. type Execer interface { Exec(query string, args []interface{}) (Result, error) } diff --git a/src/pkg/exp/sql/fakedb_test.go b/src/pkg/exp/sql/fakedb_test.go index c8a19974d6..17028e2cc3 100644 --- a/src/pkg/exp/sql/fakedb_test.go +++ b/src/pkg/exp/sql/fakedb_test.go @@ -195,6 +195,29 @@ func (c *fakeConn) Close() error { return nil } +func checkSubsetTypes(args []interface{}) error { + for n, arg := range args { + switch arg.(type) { + case int64, float64, bool, nil, []byte, string: + default: + return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) + } + } + return nil +} + +func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args of of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + func errf(msg string, args ...interface{}) error { return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) } @@ -323,6 +346,11 @@ func (s *fakeStmt) Close() error { } func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) { + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + db := s.c.db switch s.cmd { case "WIPE": @@ -377,6 +405,11 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) { } func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) { + err := checkSubsetTypes(args) + if err != nil { + return nil, err + } + db := s.c.db if len(args) != s.placeholders { panic("error in pkg db; should only get here if size is correct") diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go index 291af7f67d..d3677afb3b 100644 --- a/src/pkg/exp/sql/sql.go +++ b/src/pkg/exp/sql/sql.go @@ -88,8 +88,9 @@ type DB struct { driver driver.Driver dsn string - mu sync.Mutex + mu sync.Mutex // protects freeConn and closed freeConn []driver.Conn + closed bool } // Open opens a database specified by its database driver name and a @@ -106,6 +107,22 @@ func Open(driverName, dataSourceName string) (*DB, error) { return &DB{driver: driver, dsn: dataSourceName}, nil } +// Close closes the database, releasing any open resources. +func (db *DB) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + var err error + for _, c := range db.freeConn { + err1 := c.Close() + if err1 != nil { + err = err1 + } + } + db.freeConn = nil + db.closed = true + return err +} + func (db *DB) maxIdleConns() int { const defaultMaxIdleConns = 2 // TODO(bradfitz): ask driver, if supported, for its default preference @@ -116,6 +133,9 @@ func (db *DB) maxIdleConns() int { // conn returns a newly-opened or cached driver.Conn func (db *DB) conn() (driver.Conn, error) { db.mu.Lock() + if db.closed { + return nil, errors.New("sql: database is closed") + } if n := len(db.freeConn); n > 0 { conn := db.freeConn[n-1] db.freeConn = db.freeConn[:n-1] @@ -140,11 +160,13 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { } func (db *DB) putConn(c driver.Conn) { - if n := len(db.freeConn); n < db.maxIdleConns() { + db.mu.Lock() + defer db.mu.Unlock() + if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() { db.freeConn = append(db.freeConn, c) return } - db.closeConn(c) + db.closeConn(c) // TODO(bradfitz): release lock before calling this? } func (db *DB) closeConn(c driver.Conn) { @@ -180,17 +202,11 @@ func (db *DB) Prepare(query string) (*Stmt, error) { // Exec executes a query without returning any rows. func (db *DB) Exec(query string, args ...interface{}) (Result, error) { - // Optional fast path, if the driver implements driver.Execer. - if execer, ok := db.driver.(driver.Execer); ok { - resi, err := execer.Exec(query, args) - if err != nil { - return nil, err - } - return result{resi}, nil + sargs, err := subsetTypeArgs(args) + if err != nil { + return nil, err } - // If the driver does not implement driver.Execer, we need - // a connection. ci, err := db.conn() if err != nil { return nil, err @@ -198,11 +214,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { defer db.putConn(ci) if execer, ok := ci.(driver.Execer); ok { - resi, err := execer.Exec(query, args) - if err != nil { - return nil, err + resi, err := execer.Exec(query, sargs) + if err != driver.ErrSkip { + if err != nil { + return nil, err + } + return result{resi}, nil } - return result{resi}, nil } sti, err := ci.Prepare(query) @@ -210,7 +228,8 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { return nil, err } defer sti.Close() - resi, err := sti.Exec(args) + + resi, err := sti.Exec(sargs) if err != nil { return nil, err } @@ -386,7 +405,13 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { return nil, err } defer sti.Close() - resi, err := sti.Exec(args) + + sargs, err := subsetTypeArgs(args) + if err != nil { + return nil, err + } + + resi, err := sti.Exec(sargs) if err != nil { return nil, err } @@ -548,7 +573,11 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { if len(args) != si.NumInput() { return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args)) } - rowsi, err := si.Query(args) + sargs, err := subsetTypeArgs(args) + if err != nil { + return nil, err + } + rowsi, err := si.Query(sargs) if err != nil { s.db.putConn(ci) return nil, err diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go index eb1bb58966..d365f6ba19 100644 --- a/src/pkg/exp/sql/sql_test.go +++ b/src/pkg/exp/sql/sql_test.go @@ -34,8 +34,16 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) { } } +func closeDB(t *testing.T, db *DB) { + err := db.Close() + if err != nil { + t.Fatalf("error closing DB: %v", err) + } +} + func TestQuery(t *testing.T) { db := newTestDB(t, "people") + defer closeDB(t, db) var name string var age int @@ -69,6 +77,7 @@ func TestQuery(t *testing.T) { func TestStatementQueryRow(t *testing.T) { db := newTestDB(t, "people") + defer closeDB(t, db) stmt, err := db.Prepare("SELECT|people|age|name=?") if err != nil { t.Fatalf("Prepare: %v", err) @@ -94,6 +103,7 @@ func TestStatementQueryRow(t *testing.T) { // just a test of fakedb itself func TestBogusPreboundParameters(t *testing.T) { db := newTestDB(t, "foo") + defer closeDB(t, db) exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion") if err == nil { @@ -106,6 +116,7 @@ func TestBogusPreboundParameters(t *testing.T) { func TestDb(t *testing.T) { db := newTestDB(t, "foo") + defer closeDB(t, db) exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") stmt, err := db.Prepare("INSERT|t1|name=?,age=?") if err != nil { -- 2.50.0