]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: add Conn.Raw to expose the driver Conn safely
authorDaniel Theophanes <kardianos@gmail.com>
Fri, 26 Apr 2019 21:09:07 +0000 (14:09 -0700)
committerDaniel Theophanes <kardianos@gmail.com>
Thu, 13 Jun 2019 16:49:52 +0000 (16:49 +0000)
Exposing the underlying driver conn will allow the use of the
standard connection pool while still able to run special function
directly on the driver.

Fixes #29835

Change-Id: Ib6d3b9535e730f008916805ae3bf76e4494c88f9
Reviewed-on: https://go-review.googlesource.com/c/go/+/174182
Run-TryBot: Daniel Theophanes <kardianos@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/database/sql/sql.go
src/database/sql/sql_test.go

index 27adf691227cfe7f738a819a6e1a73665daa18bc..5c5b7dc7e973e9486eb27706a0b2f5f35dc6e696 100644 (file)
@@ -1792,6 +1792,8 @@ type Conn struct {
        done int32
 }
 
+// grabConn takes a context to implement stmtConnGrabber
+// but the context is not used.
 func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
        if atomic.LoadInt32(&c.done) != 0 {
                return nil, nil, ErrConnDone
@@ -1856,6 +1858,39 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error)
        return c.db.prepareDC(ctx, dc, release, c, query)
 }
 
+// Raw executes f exposing the underlying driver connection for the
+// duration of f. The driverConn must not be used outside of f.
+//
+// Once f returns and err is nil, the Conn will continue to be usable
+// until Conn.Close is called.
+func (c *Conn) Raw(f func(driverConn interface{}) error) (err error) {
+       var dc *driverConn
+       var release releaseConn
+
+       // grabConn takes a context to implement stmtConnGrabber, but the context is not used.
+       dc, release, err = c.grabConn(nil)
+       if err != nil {
+               return
+       }
+       fPanic := true
+       dc.Mutex.Lock()
+       defer func() {
+               dc.Mutex.Unlock()
+
+               // If f panics fPanic will remain true.
+               // Ensure an error is passed to release so the connection
+               // may be discarded.
+               if fPanic {
+                       err = driver.ErrBadConn
+               }
+               release(err)
+       }()
+       err = f(dc.ci)
+       fPanic = false
+
+       return
+}
+
 // BeginTx starts a transaction.
 //
 // The provided context is used until the transaction is committed or rolled back.
index 260374d4138a77f403d7fcc80639dbbf975d5a26..a95b70cadb1c8877a2ce9c68cef132ff08a63ce3 100644 (file)
@@ -1339,6 +1339,54 @@ func TestConnQuery(t *testing.T) {
        }
 }
 
+func TestConnRaw(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       conn, err := db.Conn(ctx)
+       if err != nil {
+               t.Fatal(err)
+       }
+       conn.dc.ci.(*fakeConn).skipDirtySession = true
+       defer conn.Close()
+
+       sawFunc := false
+       err = conn.Raw(func(dc interface{}) error {
+               sawFunc = true
+               if _, ok := dc.(*fakeConn); !ok {
+                       return fmt.Errorf("got %T want *fakeConn", dc)
+               }
+               return nil
+       })
+       if err != nil {
+               t.Fatal(err)
+       }
+       if !sawFunc {
+               t.Fatal("Raw func not called")
+       }
+
+       func() {
+               defer func() {
+                       x := recover()
+                       if x == nil {
+                               t.Fatal("expected panic")
+                       }
+                       conn.closemu.Lock()
+                       closed := conn.dc == nil
+                       conn.closemu.Unlock()
+                       if !closed {
+                               t.Fatal("expected connection to be closed after panic")
+                       }
+               }()
+               err = conn.Raw(func(dc interface{}) error {
+                       panic("Conn.Raw panic should return an error")
+               })
+               t.Fatal("expected panic from Raw func")
+       }()
+}
+
 func TestCursorFake(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)