]> Cypherpunks repositories - gostls13.git/commitdiff
exp/sql: copy when scanning into []byte by default
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 17 Jan 2012 18:44:35 +0000 (10:44 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 17 Jan 2012 18:44:35 +0000 (10:44 -0800)
Fixes #2698

R=rsc
CC=golang-dev
https://golang.org/cl/5539060

src/pkg/exp/sql/sql.go
src/pkg/exp/sql/sql_test.go

index 4e68c3ee0952b6e778a0facf650862212164cbb9..cba7e9ebe57ff0301e652ce78a90d9bb760e15a3 100644 (file)
@@ -30,6 +30,11 @@ func Register(name string, driver driver.Driver) {
        drivers[name] = driver
 }
 
+// RawBytes is a byte slice that holds a reference to memory owned by
+// the database itself. After a Scan into a RawBytes, the slice is only
+// valid until the next call to Next, Scan, or Close.
+type RawBytes []byte
+
 // NullableString represents a string that may be null.
 // NullableString implements the ScannerInto interface so
 // it can be used as a scan destination:
@@ -760,9 +765,13 @@ func (rs *Rows) Columns() ([]string, error) {
 }
 
 // Scan copies the columns in the current row into the values pointed
-// at by dest. If dest contains pointers to []byte, the slices should
-// not be modified and should only be considered valid until the next
-// call to Next or Scan.
+// at by dest.
+//
+// If an argument has type *[]byte, Scan saves in that argument a copy
+// of the corresponding data. The copy is owned by the caller and can
+// be modified and held indefinitely. The copy can be avoided by using
+// an argument of type *RawBytes instead; see the documentation for
+// RawBytes for restrictions on its use.
 func (rs *Rows) Scan(dest ...interface{}) error {
        if rs.closed {
                return errors.New("sql: Rows closed")
@@ -782,6 +791,18 @@ func (rs *Rows) Scan(dest ...interface{}) error {
                        return fmt.Errorf("sql: Scan error on column index %d: %v", i, err)
                }
        }
+       for _, dp := range dest {
+               b, ok := dp.(*[]byte)
+               if !ok {
+                       continue
+               }
+               if _, ok = dp.(*RawBytes); ok {
+                       continue
+               }
+               clone := make([]byte, len(*b))
+               copy(clone, *b)
+               *b = clone
+       }
        return nil
 }
 
@@ -838,6 +859,9 @@ func (r *Row) Scan(dest ...interface{}) error {
        // they were obtained from the network anyway) But for now we
        // don't care.
        for _, dp := range dest {
+               if _, ok := dp.(*RawBytes); ok {
+                       return errors.New("sql: RawBytes isn't allowed on Row.Scan")
+               }
                b, ok := dp.(*[]byte)
                if !ok {
                        continue
index 3f98a8cd9f288ccd9fdc65da3b33c75419afc7b1..30cd97d17681c0c7a40176694412eb14a2da3682 100644 (file)
@@ -76,7 +76,7 @@ func TestQuery(t *testing.T) {
                {age: 3, name: "Chris"},
        }
        if !reflect.DeepEqual(got, want) {
-               t.Logf(" got: %#v\nwant: %#v", got, want)
+               t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
        }
 
        // And verify that the final rows.Next() call, which hit EOF,
@@ -86,6 +86,43 @@ func TestQuery(t *testing.T) {
        }
 }
 
+func TestByteOwnership(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+       rows, err := db.Query("SELECT|people|name,photo|")
+       if err != nil {
+               t.Fatalf("Query: %v", err)
+       }
+       type row struct {
+               name  []byte
+               photo RawBytes
+       }
+       got := []row{}
+       for rows.Next() {
+               var r row
+               err = rows.Scan(&r.name, &r.photo)
+               if err != nil {
+                       t.Fatalf("Scan: %v", err)
+               }
+               got = append(got, r)
+       }
+       corruptMemory := []byte("\xffPHOTO")
+       want := []row{
+               {name: []byte("Alice"), photo: corruptMemory},
+               {name: []byte("Bob"), photo: corruptMemory},
+               {name: []byte("Chris"), photo: corruptMemory},
+       }
+       if !reflect.DeepEqual(got, want) {
+               t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+       }
+
+       var photo RawBytes
+       err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
+       if err == nil {
+               t.Error("want error scanning into RawBytes from QueryRow")
+       }
+}
+
 func TestRowsColumns(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)
@@ -300,6 +337,6 @@ func TestQueryRowClosingStmt(t *testing.T) {
        }
        fakeConn := db.freeConn[0].(*fakeConn)
        if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
-               t.Logf("statement close mismatch: made %d, closed %d", made, closed)
+               t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
        }
 }