]> Cypherpunks repositories - gostls13.git/commitdiff
exp/sql: finish transactions, flesh out types, docs
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 2 Nov 2011 18:46:04 +0000 (11:46 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 2 Nov 2011 18:46:04 +0000 (11:46 -0700)
Fixes #2328 (float, bool)

R=rsc, r
CC=golang-dev
https://golang.org/cl/5294067

src/pkg/Makefile
src/pkg/exp/sql/convert.go
src/pkg/exp/sql/convert_test.go
src/pkg/exp/sql/driver/driver.go
src/pkg/exp/sql/driver/types.go
src/pkg/exp/sql/driver/types_test.go [new file with mode: 0644]
src/pkg/exp/sql/fakedb_test.go
src/pkg/exp/sql/sql.go

index f23f7fc4eda7e8db518ff1b302da69d4ca9db275..3d11502f241324cc1154338529824f3f4fc79531 100644 (file)
@@ -203,7 +203,6 @@ NOTEST+=\
        exp/ebnflint\
        exp/gui\
        exp/gui/x11\
-       exp/sql/driver\
        go/doc\
        hash\
        http/pprof\
index b1feef0eb828329f2508462161ec148048c1ed1f..e46cebe9a3da51f5e8901b41dfad61222209e52e 100644 (file)
@@ -8,6 +8,7 @@ package sql
 
 import (
        "errors"
+       "exp/sql/driver"
        "fmt"
        "reflect"
        "strconv"
@@ -36,10 +37,11 @@ func convertAssign(dest, src interface{}) error {
                }
        }
 
-       sv := reflect.ValueOf(src)
+       var sv reflect.Value
 
        switch d := dest.(type) {
        case *string:
+               sv = reflect.ValueOf(src)
                switch sv.Kind() {
                case reflect.Bool,
                        reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
@@ -48,6 +50,12 @@ func convertAssign(dest, src interface{}) error {
                        *d = fmt.Sprintf("%v", src)
                        return nil
                }
+       case *bool:
+               bv, err := driver.Bool.ConvertValue(src)
+               if err == nil {
+                       *d = bv.(bool)
+               }
+               return err
        }
 
        if scanner, ok := dest.(ScannerInto); ok {
@@ -59,6 +67,10 @@ func convertAssign(dest, src interface{}) error {
                return errors.New("destination not a pointer")
        }
 
+       if !sv.IsValid() {
+               sv = reflect.ValueOf(src)
+       }
+
        dv := reflect.Indirect(dpv)
        if dv.Kind() == sv.Kind() {
                dv.Set(sv)
@@ -67,40 +79,49 @@ func convertAssign(dest, src interface{}) error {
 
        switch dv.Kind() {
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-               if s, ok := asString(src); ok {
-                       i64, err := strconv.Atoi64(s)
-                       if err != nil {
-                               return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
-                       }
-                       if dv.OverflowInt(i64) {
-                               return fmt.Errorf("string %q overflows %s", s, dv.Kind())
-                       }
-                       dv.SetInt(i64)
-                       return nil
+               s := asString(src)
+               i64, err := strconv.Atoi64(s)
+               if err != nil {
+                       return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
+               }
+               if dv.OverflowInt(i64) {
+                       return fmt.Errorf("string %q overflows %s", s, dv.Kind())
                }
+               dv.SetInt(i64)
+               return nil
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
-               if s, ok := asString(src); ok {
-                       u64, err := strconv.Atoui64(s)
-                       if err != nil {
-                               return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
-                       }
-                       if dv.OverflowUint(u64) {
-                               return fmt.Errorf("string %q overflows %s", s, dv.Kind())
-                       }
-                       dv.SetUint(u64)
-                       return nil
+               s := asString(src)
+               u64, err := strconv.Atoui64(s)
+               if err != nil {
+                       return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
+               }
+               if dv.OverflowUint(u64) {
+                       return fmt.Errorf("string %q overflows %s", s, dv.Kind())
+               }
+               dv.SetUint(u64)
+               return nil
+       case reflect.Float32, reflect.Float64:
+               s := asString(src)
+               f64, err := strconv.Atof64(s)
+               if err != nil {
+                       return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
+               }
+               if dv.OverflowFloat(f64) {
+                       return fmt.Errorf("value %q overflows %s", s, dv.Kind())
                }
+               dv.SetFloat(f64)
+               return nil
        }
 
        return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
 }
 
-func asString(src interface{}) (s string, ok bool) {
+func asString(src interface{}) string {
        switch v := src.(type) {
        case string:
-               return v, true
+               return v
        case []byte:
-               return string(v), true
+               return string(v)
        }
-       return "", false
+       return fmt.Sprintf("%v", src)
 }
index f85ed99978d53f6a69c9428840d4d2c9b09f0ae6..52cee9272412d697a7a33cb4ba9b54de2ebd337b 100644 (file)
@@ -17,6 +17,9 @@ type conversionTest struct {
        wantint  int64
        wantuint uint64
        wantstr  string
+       wantf32  float32
+       wantf64  float64
+       wantbool bool // used if d is of type *bool
        wanterr  string
 }
 
@@ -29,6 +32,9 @@ var (
        scanint32  int32
        scanuint8  uint8
        scanuint16 uint16
+       scanbool   bool
+       scanf32    float32
+       scanf64    float64
 )
 
 var conversionTests = []conversionTest{
@@ -53,6 +59,35 @@ var conversionTests = []conversionTest{
        {s: "256", d: &scanuint16, wantuint: 256},
        {s: "-1", d: &scanint, wantint: -1},
        {s: "foo", d: &scanint, wanterr: `converting string "foo" to a int: parsing "foo": invalid syntax`},
+
+       // True bools
+       {s: true, d: &scanbool, wantbool: true},
+       {s: "True", d: &scanbool, wantbool: true},
+       {s: "TRUE", d: &scanbool, wantbool: true},
+       {s: "1", d: &scanbool, wantbool: true},
+       {s: 1, d: &scanbool, wantbool: true},
+       {s: int64(1), d: &scanbool, wantbool: true},
+       {s: uint16(1), d: &scanbool, wantbool: true},
+
+       // False bools
+       {s: false, d: &scanbool, wantbool: false},
+       {s: "false", d: &scanbool, wantbool: false},
+       {s: "FALSE", d: &scanbool, wantbool: false},
+       {s: "0", d: &scanbool, wantbool: false},
+       {s: 0, d: &scanbool, wantbool: false},
+       {s: int64(0), d: &scanbool, wantbool: false},
+       {s: uint16(0), d: &scanbool, wantbool: false},
+
+       // Not bools
+       {s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
+       {s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},
+
+       // Floats
+       {s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
+       {s: int64(1), d: &scanf64, wantf64: float64(1)},
+       {s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
+       {s: "1.5", d: &scanf32, wantf32: float32(1.5)},
+       {s: "1.5", d: &scanf64, wantf64: float64(1.5)},
 }
 
 func intValue(intptr interface{}) int64 {
@@ -63,6 +98,14 @@ func uintValue(intptr interface{}) uint64 {
        return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
 }
 
+func float64Value(ptr interface{}) float64 {
+       return *(ptr.(*float64))
+}
+
+func float32Value(ptr interface{}) float32 {
+       return *(ptr.(*float32))
+}
+
 func TestConversions(t *testing.T) {
        for n, ct := range conversionTests {
                err := convertAssign(ct.d, ct.s)
@@ -86,6 +129,15 @@ func TestConversions(t *testing.T) {
                if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
                        errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
                }
+               if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
+                       errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
+               }
+               if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
+                       errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
+               }
+               if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
+                       errf("want bool %v, got %v", ct.wantbool, *bp)
+               }
        }
 }
 
index 52714e817a3007e562a40e0f8d1f00240b4e1a8e..6a51c342415a33bec30bf2e4ba4ee7e5e7a51432 100644 (file)
@@ -24,9 +24,13 @@ import "errors"
 // Driver is the interface that must be implemented by a database
 // driver.
 type Driver interface {
-       // Open returns a new or cached connection to the database.
+       // Open returns a new connection to the database.
        // The name is a string in a driver-specific format.
        //
+       // Open may return a cached connection (one previously
+       // closed), but doing so is unnecessary; the sql package
+       // maintains a pool of idle connections for efficient re-use.
+       //
        // The returned connection is only used by one goroutine at a
        // time.
        Open(name string) (Conn, error)
@@ -59,8 +63,12 @@ type Conn interface {
 
        // Close invalidates and potentially stops any current
        // prepared statements and transactions, marking this
-       // connection as no longer in use.  The driver may cache or
-       // close its underlying connection to its database.
+       // connection as no longer in use.
+       //
+       // Because the sql package maintains a free pool of
+       // connections and only calls Close when there's a surplus of
+       // idle connections, it shouldn't be necessary for drivers to
+       // do their own connection caching.
        Close() error
 
        // Begin starts and returns a new transaction.
index 9faf32f671ab60dc02c8403b9099abd8d4bc4131..6e0ce4339cc86faae5ac60d83616f7a13a05e7a6 100644 (file)
@@ -11,6 +11,21 @@ import (
 )
 
 // ValueConverter is the interface providing the ConvertValue method.
+//
+// Various implementations of ValueConverter are provided by the
+// driver package to provide consistent implementations of conversions
+// between drivers.  The ValueConverters have several uses:
+//
+//  * converting from the subset types as provided by the sql package
+//    into a database table's specific column type and making sure it
+//    fits, such as making sure a particular int64 fits in a
+//    table's uint16 column.
+//
+//  * converting a value as given from the database into one of the
+//    subset types.
+//
+//  * by the sql package, for converting from a driver's subset type
+//    to a user's type in a scan.
 type ValueConverter interface {
        // ConvertValue converts a value to a restricted subset type.
        ConvertValue(v interface{}) (interface{}, error)
@@ -19,15 +34,56 @@ type ValueConverter interface {
 // Bool is a ValueConverter that converts input values to bools.
 //
 // The conversion rules are:
-//  - .... TODO(bradfitz): TBD
+//  - booleans are returned unchanged
+//  - for integer types,
+//       1 is true
+//       0 is false,
+//       other integers are an error
+//  - for strings and []byte, same rules as strconv.Atob
+//  - all other types are an error
 var Bool boolType
 
 type boolType struct{}
 
 var _ ValueConverter = boolType{}
 
-func (boolType) ConvertValue(v interface{}) (interface{}, error) {
-       return nil, fmt.Errorf("TODO(bradfitz): bool conversions")
+func (boolType) String() string { return "Bool" }
+
+func (boolType) ConvertValue(src interface{}) (interface{}, error) {
+       switch s := src.(type) {
+       case bool:
+               return s, nil
+       case string:
+               b, err := strconv.Atob(s)
+               if err != nil {
+                       return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
+               }
+               return b, nil
+       case []byte:
+               b, err := strconv.Atob(string(s))
+               if err != nil {
+                       return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
+               }
+               return b, nil
+       }
+
+       sv := reflect.ValueOf(src)
+       switch sv.Kind() {
+       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+               iv := sv.Int()
+               if iv == 1 || iv == 0 {
+                       return iv == 1, nil
+               }
+               return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv)
+       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+               uv := sv.Uint()
+               if uv == 1 || uv == 0 {
+                       return uv == 1, nil
+               }
+               return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv)
+       }
+
+       return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src)
 }
 
 // Int32 is a ValueConverter that converts input values to int64,
diff --git a/src/pkg/exp/sql/driver/types_test.go b/src/pkg/exp/sql/driver/types_test.go
new file mode 100644 (file)
index 0000000..4b049e2
--- /dev/null
@@ -0,0 +1,57 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package driver
+
+import (
+       "reflect"
+       "testing"
+)
+
+type valueConverterTest struct {
+       c   ValueConverter
+       in  interface{}
+       out interface{}
+       err string
+}
+
+var valueConverterTests = []valueConverterTest{
+       {Bool, "true", true, ""},
+       {Bool, "True", true, ""},
+       {Bool, []byte("t"), true, ""},
+       {Bool, true, true, ""},
+       {Bool, "1", true, ""},
+       {Bool, 1, true, ""},
+       {Bool, int64(1), true, ""},
+       {Bool, uint16(1), true, ""},
+       {Bool, "false", false, ""},
+       {Bool, false, false, ""},
+       {Bool, "0", false, ""},
+       {Bool, 0, false, ""},
+       {Bool, int64(0), false, ""},
+       {Bool, uint16(0), false, ""},
+       {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"},
+       {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"},
+}
+
+func TestValueConverters(t *testing.T) {
+       for i, tt := range valueConverterTests {
+               out, err := tt.c.ConvertValue(tt.in)
+               goterr := ""
+               if err != nil {
+                       goterr = err.Error()
+               }
+               if goterr != tt.err {
+                       t.Errorf("test %d: %s(%T(%v)) error = %q; want error = %q",
+                               i, tt.c, tt.in, tt.in, goterr, tt.err)
+               }
+               if tt.err != "" {
+                       continue
+               }
+               if !reflect.DeepEqual(out, tt.out) {
+                       t.Errorf("test %d: %s(%T(%v)) = %v (%T); want %v (%T)",
+                               i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
+               }
+       }
+}
index 289294bee265a6806f6940a0a7d08cb15dcd7b0f..c8a19974d641df1317c832829ba7929a7a71df9c 100644 (file)
@@ -476,7 +476,7 @@ func (rc *rowsCursor) Next(dest []interface{}) error {
        for i, v := range rc.rows[rc.pos].cols {
                // TODO(bradfitz): convert to subset types? naah, I
                // think the subset types should only be input to
-               // driver, but the db package should be able to handle
+               // driver, but the sql package should be able to handle
                // a wider range of types coming out of drivers. all
                // for ease of drivers, and to prevent drivers from
                // messing up conversions or doing them differently.
index 4f1c539127cf51d23ebe50895315d8582431938a..1af8e063cf3ed20bdb3fa793db1e9afa464d9f87 100644 (file)
@@ -10,7 +10,6 @@ import (
        "errors"
        "fmt"
        "io"
-       "runtime"
        "sync"
 
        "exp/sql/driver"
@@ -192,13 +191,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
 
        // If the driver does not implement driver.Execer, we need
        // a connection.
-       conn, err := db.conn()
+       ci, err := db.conn()
        if err != nil {
                return nil, err
        }
-       defer db.putConn(conn)
+       defer db.putConn(ci)
 
-       if execer, ok := conn.(driver.Execer); ok {
+       if execer, ok := ci.(driver.Execer); ok {
                resi, err := execer.Exec(query, args)
                if err != nil {
                        return nil, err
@@ -206,7 +205,7 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
                return result{resi}, nil
        }
 
-       sti, err := conn.Prepare(query)
+       sti, err := ci.Prepare(query)
        if err != nil {
                return nil, err
        }
@@ -233,18 +232,26 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
 // Row's Scan method is called.
 func (db *DB) QueryRow(query string, args ...interface{}) *Row {
        rows, err := db.Query(query, args...)
-       if err != nil {
-               return &Row{err: err}
-       }
-       return &Row{rows: rows}
+       return &Row{rows: rows, err: err}
 }
 
-// Begin starts a transaction.  The isolation level is dependent on
+// Begin starts a transaction. The isolation level is dependent on
 // the driver.
 func (db *DB) Begin() (*Tx, error) {
-       // TODO(bradfitz): add another method for beginning a transaction
-       // at a specific isolation level.
-       panic(todo())
+       ci, err := db.conn()
+       if err != nil {
+               return nil, err
+       }
+       txi, err := ci.Begin()
+       if err != nil {
+               db.putConn(ci)
+               return nil, fmt.Errorf("sql: failed to Begin transaction: %v", err)
+       }
+       return &Tx{
+               db:  db,
+               ci:  ci,
+               txi: txi,
+       }, nil
 }
 
 // DriverDatabase returns the database's underlying driver.
@@ -253,41 +260,158 @@ func (db *DB) Driver() driver.Driver {
 }
 
 // Tx is an in-progress database transaction.
+//
+// A transaction must end with a call to Commit or Rollback.
+//
+// After a call to Commit or Rollback, all operations on the
+// transaction fail with ErrTransactionFinished.
 type Tx struct {
+       db *DB
+
+       // ci is owned exclusively until Commit or Rollback, at which point
+       // it's returned with putConn.
+       ci  driver.Conn
+       txi driver.Tx
+
+       // cimu is held while somebody is using ci (between grabConn
+       // and releaseConn)
+       cimu sync.Mutex
 
+       // done transitions from false to true exactly once, on Commit
+       // or Rollback. once done, all operations fail with
+       // ErrTransactionFinished.
+       done bool
+}
+
+var ErrTransactionFinished = errors.New("sql: Transaction has already been committed or rolled back")
+
+func (tx *Tx) close() {
+       if tx.done {
+               panic("double close") // internal error
+       }
+       tx.done = true
+       tx.db.putConn(tx.ci)
+       tx.ci = nil
+       tx.txi = nil
+}
+
+func (tx *Tx) grabConn() (driver.Conn, error) {
+       if tx.done {
+               return nil, ErrTransactionFinished
+       }
+       tx.cimu.Lock()
+       return tx.ci, nil
+}
+
+func (tx *Tx) releaseConn() {
+       tx.cimu.Unlock()
 }
 
 // Commit commits the transaction.
 func (tx *Tx) Commit() error {
-       panic(todo())
+       if tx.done {
+               return ErrTransactionFinished
+       }
+       defer tx.close()
+       return tx.txi.Commit()
 }
 
 // Rollback aborts the transaction.
 func (tx *Tx) Rollback() error {
-       panic(todo())
+       if tx.done {
+               return ErrTransactionFinished
+       }
+       defer tx.close()
+       return tx.txi.Rollback()
 }
 
 // Prepare creates a prepared statement.
+//
+// The statement is only valid within the scope of this transaction.
 func (tx *Tx) Prepare(query string) (*Stmt, error) {
-       panic(todo())
+       // TODO(bradfitz): the restriction that the returned statement
+       // is only valid for this Transaction is lame and negates a
+       // lot of the benefit of prepared statements.  We could be
+       // more efficient here and either provide a method to take an
+       // existing Stmt (created on perhaps a different Conn), and
+       // re-create it on this Conn if necessary. Or, better: keep a
+       // map in DB of query string to Stmts, and have Stmt.Execute
+       // do the right thing and re-prepare if the Conn in use
+       // doesn't have that prepared statement.  But we'll want to
+       // avoid caching the statement in the case where we only call
+       // conn.Prepare implicitly (such as in db.Exec or tx.Exec),
+       // but the caller package can't be holding a reference to the
+       // returned statement.  Perhaps just looking at the reference
+       // count (by noting Stmt.Close) would be enough. We might also
+       // want a finalizer on Stmt to drop the reference count.
+       ci, err := tx.grabConn()
+       if err != nil {
+               return nil, err
+       }
+       defer tx.releaseConn()
+
+       si, err := ci.Prepare(query)
+       if err != nil {
+               return nil, err
+       }
+
+       stmt := &Stmt{
+               db:    tx.db,
+               tx:    tx,
+               txsi:  si,
+               query: query,
+       }
+       return stmt, nil
 }
 
 // Exec executes a query that doesn't return rows.
 // For example: an INSERT and UPDATE.
-func (tx *Tx) Exec(query string, args ...interface{}) {
-       panic(todo())
+func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
+       ci, err := tx.grabConn()
+       if err != nil {
+               return nil, err
+       }
+       defer tx.releaseConn()
+
+       if execer, ok := ci.(driver.Execer); ok {
+               resi, err := execer.Exec(query, args)
+               if err != nil {
+                       return nil, err
+               }
+               return result{resi}, nil
+       }
+
+       sti, err := ci.Prepare(query)
+       if err != nil {
+               return nil, err
+       }
+       defer sti.Close()
+       resi, err := sti.Exec(args)
+       if err != nil {
+               return nil, err
+       }
+       return result{resi}, nil
 }
 
 // Query executes a query that returns rows, typically a SELECT.
 func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
-       panic(todo())
+       if tx.done {
+               return nil, ErrTransactionFinished
+       }
+       stmt, err := tx.Prepare(query)
+       if err != nil {
+               return nil, err
+       }
+       defer stmt.Close()
+       return stmt.Query(args...)
 }
 
 // QueryRow executes a query that is expected to return at most one row.
 // QueryRow always return a non-nil value. Errors are deferred until
 // Row's Scan method is called.
 func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
-       panic(todo())
+       rows, err := tx.Query(query, args...)
+       return &Row{rows: rows, err: err}
 }
 
 // connStmt is a prepared statement on a particular connection.
@@ -302,24 +426,28 @@ type Stmt struct {
        db    *DB    // where we came from
        query string // that created the Sttm
 
-       mu     sync.Mutex
+       // If in a transaction, else both nil:
+       tx   *Tx
+       txsi driver.Stmt
+
+       mu     sync.Mutex // protects the rest of the fields
        closed bool
-       css    []connStmt // can use any that have idle connections
-}
 
-func todo() string {
-       _, file, line, _ := runtime.Caller(1)
-       return fmt.Sprintf("%s:%d: TODO: implement", file, line)
+       // css is a list of underlying driver statement interfaces
+       // that are valid on particular connections.  This is only
+       // used if tx == nil and one is found that has idle
+       // connections.  If tx != nil, txsi is always used.
+       css []connStmt
 }
 
 // Exec executes a prepared statement with the given arguments and
 // returns a Result summarizing the effect of the statement.
 func (s *Stmt) Exec(args ...interface{}) (Result, error) {
-       ci, si, err := s.connStmt()
+       _, releaseConn, si, err := s.connStmt()
        if err != nil {
                return nil, err
        }
-       defer s.db.putConn(ci)
+       defer releaseConn()
 
        if want := si.NumInput(); len(args) != want {
                return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
@@ -353,11 +481,29 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
        return result{resi}, nil
 }
 
-func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) {
+// connStmt returns a free driver connection on which to execute the
+// statement, a function to call to release the connection, and a
+// statement bound to that connection.
+func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
        s.mu.Lock()
        if s.closed {
-               return nil, nil, errors.New("db: statement is closed")
+               s.mu.Unlock()
+               err = errors.New("db: statement is closed")
+               return
        }
+
+       // In a transaction, we always use the connection that the
+       // transaction was created on.
+       if s.tx != nil {
+               s.mu.Unlock()
+               ci, err = s.tx.grabConn() // blocks, waiting for the connection.
+               if err != nil {
+                       return
+               }
+               releaseConn = func() { s.tx.releaseConn() }
+               return ci, releaseConn, s.txsi, nil
+       }
+
        var cs connStmt
        match := false
        for _, v := range s.css {
@@ -375,11 +521,11 @@ func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) {
        if !match {
                ci, err := s.db.conn()
                if err != nil {
-                       return nil, nil, err
+                       return nil, nil, nil, err
                }
                si, err := ci.Prepare(s.query)
                if err != nil {
-                       return nil, nil, err
+                       return nil, nil, nil, err
                }
                s.mu.Lock()
                cs = connStmt{ci, si}
@@ -387,13 +533,15 @@ func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) {
                s.mu.Unlock()
        }
 
-       return cs.ci, cs.si, nil
+       conn := cs.ci
+       releaseConn = func() { s.db.putConn(conn) }
+       return conn, releaseConn, cs.si, nil
 }
 
 // Query executes a prepared query statement with the given arguments
 // and returns the query results as a *Rows.
 func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
-       ci, si, err := s.connStmt(args...)
+       ci, releaseConn, si, err := s.connStmt()
        if err != nil {
                return nil, err
        }
@@ -405,11 +553,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
                s.db.putConn(ci)
                return nil, err
        }
-       // Note: ownership of ci passes to the *Rows
+       // Note: ownership of ci passes to the *Rows, to be freed
+       // with releaseConn.
        rows := &Rows{
-               db:    s.db,
-               ci:    ci,
-               rowsi: rowsi,
+               db:          s.db,
+               ci:          ci,
+               releaseConn: releaseConn,
+               rowsi:       rowsi,
        }
        return rows, nil
 }
@@ -436,19 +586,24 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row {
 // Close closes the statement.
 func (s *Stmt) Close() error {
        s.mu.Lock()
-       defer s.mu.Unlock() // TODO(bradfitz): move this unlock after 'closed = true'?
+       defer s.mu.Unlock()
        if s.closed {
                return nil
        }
        s.closed = true
-       for _, v := range s.css {
-               if ci, match := s.db.connIfFree(v.ci); match {
-                       v.si.Close()
-                       s.db.putConn(ci)
-               } else {
-                       // TODO(bradfitz): care that we can't close
-                       // this statement because the statement's
-                       // connection is in use?
+
+       if s.tx != nil {
+               s.txsi.Close()
+       } else {
+               for _, v := range s.css {
+                       if ci, match := s.db.connIfFree(v.ci); match {
+                               v.si.Close()
+                               s.db.putConn(ci)
+                       } else {
+                               // TODO(bradfitz): care that we can't close
+                               // this statement because the statement's
+                               // connection is in use?
+                       }
                }
        }
        return nil
@@ -468,9 +623,10 @@ func (s *Stmt) Close() error {
 //     err = rows.Error() // get any Error encountered during iteration
 //     ...
 type Rows struct {
-       db    *DB
-       ci    driver.Conn // owned; must be returned when Rows is closed
-       rowsi driver.Rows
+       db          *DB
+       ci          driver.Conn // owned; must call putconn when closed to release
+       releaseConn func()
+       rowsi       driver.Rows
 
        closed   bool
        lastcols []interface{}
@@ -538,7 +694,7 @@ func (rs *Rows) Close() error {
        }
        rs.closed = true
        err := rs.rowsi.Close()
-       rs.db.putConn(rs.ci)
+       rs.releaseConn()
        return err
 }