]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: support returning query database types
authorDaniel Theophanes <kardianos@gmail.com>
Tue, 27 Sep 2016 20:27:02 +0000 (13:27 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 18 Oct 2016 10:52:57 +0000 (10:52 +0000)
Creates a ColumnType structure that can be extended in to future.
Allow drivers to implement what makes sense for the database.

Fixes #16652

Change-Id: Ieb1fd64eac1460107b1d3474eba5201fa300a4ec
Reviewed-on: https://go-review.googlesource.com/29961
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/database/sql/driver/driver.go
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index 6cc970f688c446a3041f6ba8a5923b8f2cf57f3a..ad988cc785b9db9969fe2d66de0110ea0d97f467 100644 (file)
@@ -11,6 +11,7 @@ package driver
 import (
        "context"
        "errors"
+       "reflect"
 )
 
 // Value is a value that drivers must be able to handle.
@@ -239,6 +240,60 @@ type RowsNextResultSet interface {
        NextResultSet() error
 }
 
+// RowsColumnTypeScanType may be implemented by Rows. It should return
+// the value type that can be used to scan types into. For example, the database
+// column type "bigint" this should return "reflect.TypeOf(int64(0))".
+type RowsColumnTypeScanType interface {
+       Rows
+       ColumnTypeScanType(index int) reflect.Type
+}
+
+// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
+// database system type name without the length. Type names should be uppercase.
+// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
+// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
+// "TIMESTAMP".
+type RowsColumnTypeDatabaseTypeName interface {
+       Rows
+       ColumnTypeDatabaseTypeName(index int) string
+}
+
+// RowsColumnTypeLength may be implemented by Rows. It should return the length
+// of the column type if the column is a variable length type. If the column is
+// not a variable length type ok should return false.
+// If length is not limited other than system limits, it should return math.MaxInt64.
+// The following are examples of returned values for various types:
+//   TEXT          (math.MaxInt64, true)
+//   varchar(10)   (10, true)
+//   nvarchar(10)  (10, true)
+//   decimal       (0, false)
+//   int           (0, false)
+//   bytea(30)     (30, true)
+type RowsColumnTypeLength interface {
+       Rows
+       ColumnTypeLength(index int) (length int64, ok bool)
+}
+
+// RowsColumnTypeNullable may be implemented by Rows. The nullable value should
+// be true if it is known the column may be null, or false if the column is known
+// to be not nullable.
+// If the column nullability is unknown, ok should be false.
+type RowsColumnTypeNullable interface {
+       Rows
+       ColumnTypeNullable(index int) (nullable, ok bool)
+}
+
+// RowsColumnTypePrecisionScale may be implemented by Rows. It should return
+// the precision and scale for decimal types. If not applicable, ok should be false.
+// The following are examples of returned values for various types:
+//   decimal(38, 4)    (38, 4, true)
+//   int               (0, 0, false)
+//   decimal           (math.MaxInt64, math.MaxInt64, true)
+type RowsColumnTypePrecisionScale interface {
+       Rows
+       ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
+}
+
 // Tx is a transaction.
 type Tx interface {
        Commit() error
index 07f50196a57c751fa4a35121740644ca9b424dec..c42f23208fb0ff5c85d67ec54c38f7a962077137 100644 (file)
@@ -11,6 +11,7 @@ import (
        "fmt"
        "io"
        "log"
+       "reflect"
        "sort"
        "strconv"
        "strings"
@@ -405,6 +406,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err
                return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
        }
        stmt.table = parts[0]
+
        stmt.colName = strings.Split(parts[1], ",")
        for n, colspec := range strings.Split(parts[2], ",") {
                if colspec == "" {
@@ -725,6 +727,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
 
        setMRows := make([][]*row, 0, 1)
        setColumns := make([][]string, 0, 1)
+       setColType := make([][]string, 0, 1)
 
        for {
                db.mu.Lock()
@@ -794,10 +797,16 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
                        mrows = append(mrows, mrow)
                }
 
+               var colType []string
+               for _, column := range s.colName {
+                       colType = append(colType, t.coltype[t.columnIndex(column)])
+               }
+
                t.mu.Unlock()
 
                setMRows = append(setMRows, mrows)
                setColumns = append(setColumns, s.colName)
+               setColType = append(setColType, colType)
 
                if s.next == nil {
                        break
@@ -806,10 +815,11 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
        }
 
        cursor := &rowsCursor{
-               posRow: -1,
-               rows:   setMRows,
-               cols:   setColumns,
-               errPos: -1,
+               posRow:  -1,
+               rows:    setMRows,
+               cols:    setColumns,
+               colType: setColType,
+               errPos:  -1,
        }
        return cursor, nil
 }
@@ -844,11 +854,12 @@ func (tx *fakeTx) Rollback() error {
 }
 
 type rowsCursor struct {
-       cols   [][]string
-       posSet int
-       posRow int
-       rows   [][]*row
-       closed bool
+       cols    [][]string
+       colType [][]string
+       posSet  int
+       posRow  int
+       rows    [][]*row
+       closed  bool
 
        // errPos and err are for making Next return early with error.
        errPos int
@@ -874,6 +885,10 @@ func (rc *rowsCursor) Columns() []string {
        return rc.cols[rc.posSet]
 }
 
+func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
+       return colTypeToReflectType(rc.colType[rc.posSet][index])
+}
+
 var rowsCursorNextHook func(dest []driver.Value) error
 
 func (rc *rowsCursor) Next(dest []driver.Value) error {
@@ -980,3 +995,29 @@ func converterForType(typ string) driver.ValueConverter {
        }
        panic("invalid fakedb column type of " + typ)
 }
+
+func colTypeToReflectType(typ string) reflect.Type {
+       switch typ {
+       case "bool":
+               return reflect.TypeOf(false)
+       case "nullbool":
+               return reflect.TypeOf(NullBool{})
+       case "int32":
+               return reflect.TypeOf(int32(0))
+       case "string":
+               return reflect.TypeOf("")
+       case "nullstring":
+               return reflect.TypeOf(NullString{})
+       case "int64":
+               return reflect.TypeOf(int64(0))
+       case "nullint64":
+               return reflect.TypeOf(NullInt64{})
+       case "float64":
+               return reflect.TypeOf(float64(0))
+       case "nullfloat64":
+               return reflect.TypeOf(NullFloat64{})
+       case "datetime":
+               return reflect.TypeOf(time.Time{})
+       }
+       panic("invalid fakedb column type of " + typ)
+}
index 616acb2be1a5210e71b0960a48675afdf6a5629c..defe9607420bba0c831e5937b0853d36cf6ab50c 100644 (file)
@@ -18,6 +18,7 @@ import (
        "errors"
        "fmt"
        "io"
+       "reflect"
        "runtime"
        "sort"
        "sync"
@@ -996,8 +997,8 @@ const maxBadConnRetries = 2
 // The caller must call the statement's Close method
 // when the statement is no longer needed.
 //
-// The provided context is for the preparation of the statment, not for the execution of
-// the statement.
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
        var stmt *Stmt
        var err error
@@ -2033,6 +2034,107 @@ func (rs *Rows) Columns() ([]string, error) {
        return rs.rowsi.Columns(), nil
 }
 
+// ColumnTypes returns column information such as column type, length,
+// and nullable. Some information may not be available from some drivers.
+func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
+       if rs.isClosed() {
+               return nil, errors.New("sql: Rows are closed")
+       }
+       if rs.rowsi == nil {
+               return nil, errors.New("sql: no Rows available")
+       }
+       return rowsColumnInfoSetup(rs.rowsi), nil
+}
+
+// ColumnType contains the name and type of a column.
+type ColumnType struct {
+       name string
+
+       hasNullable       bool
+       hasLength         bool
+       hasPrecisionScale bool
+
+       nullable     bool
+       length       int64
+       databaseType string
+       precision    int64
+       scale        int64
+       scanType     reflect.Type
+}
+
+// Name returns the name or alias of the column.
+func (ci *ColumnType) Name() string {
+       return ci.name
+}
+
+// Length returns the column type length for variable length column types such
+// as text and binary field types. If the type length is unbounded the value will
+// be math.MaxInt64 (any database limits will still apply).
+// If the column type is not variable length, such as an int, or if not supported
+// by the driver ok is false.
+func (ci *ColumnType) Length() (length int64, ok bool) {
+       return ci.length, ci.hasLength
+}
+
+// DecimalSize returns the scale and precision of a decimal type.
+// If not applicable or if not supported ok is false.
+func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
+       return ci.precision, ci.scale, ci.hasPrecisionScale
+}
+
+// ScanType returns a Go type suitable for scanning into using Rows.Scan.
+// If a driver does not support this property ScanType will return
+// the type of an empty interface.
+func (ci *ColumnType) ScanType() reflect.Type {
+       return ci.scanType
+}
+
+// Nullable returns whether the column may be null.
+// If a driver does not support this property ok will be false.
+func (ci *ColumnType) Nullable() (nullable, ok bool) {
+       return ci.nullable, ci.hasNullable
+}
+
+// DatabaseTypeName returns the database system name of the column type. If an empty
+// string is returned the driver type name is not supported.
+// Consult your driver documentation for a list of driver data types. Length specifiers
+// are not included.
+// Common type include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", "INT", "BIGINT".
+func (ci *ColumnType) DatabaseTypeName() string {
+       return ci.databaseType
+}
+
+func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
+       names := rowsi.Columns()
+
+       list := make([]*ColumnType, len(names))
+       for i := range list {
+               ci := &ColumnType{
+                       name: names[i],
+               }
+               list[i] = ci
+
+               if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
+                       ci.scanType = prop.ColumnTypeScanType(i)
+               } else {
+                       ci.scanType = reflect.TypeOf(new(interface{})).Elem()
+               }
+               if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
+                       ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
+               }
+               if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
+                       ci.length, ci.hasLength = prop.ColumnTypeLength(i)
+               }
+               if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
+                       ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
+               }
+               if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
+                       ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
+               }
+       }
+       return list
+}
+
 // Scan copies the columns in the current row into the values pointed
 // at by dest. The number of values in dest must be the same as the
 // number of columns in Rows.
index 885cadf3c66f9c54c8aceae3629791fb833ce2aa..f4b887ca963735695c1f07287f0e21f1ef490355 100644 (file)
@@ -499,6 +499,56 @@ func TestRowsColumns(t *testing.T) {
        }
 }
 
+func TestRowsColumnTypes(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+       rows, err := db.Query("SELECT|people|age,name|")
+       if err != nil {
+               t.Fatalf("Query: %v", err)
+       }
+       tt, err := rows.ColumnTypes()
+       if err != nil {
+               t.Fatalf("ColumnTypes: %v", err)
+       }
+
+       types := make([]reflect.Type, len(tt))
+       for i, tp := range tt {
+               st := tp.ScanType()
+               if st == nil {
+                       t.Errorf("scantype is null for column %q", tp.Name())
+                       continue
+               }
+               types[i] = st
+       }
+       values := make([]interface{}, len(tt))
+       for i := range values {
+               values[i] = reflect.New(types[i]).Interface()
+       }
+       ct := 0
+       for rows.Next() {
+               err = rows.Scan(values...)
+               if err != nil {
+                       t.Fatalf("failed to scan values in %v", err)
+               }
+               ct++
+               if ct == 0 {
+                       if values[0].(string) != "Bob" {
+                               t.Errorf("Expected Bob, got %v", values[0])
+                       }
+                       if values[1].(int) != 2 {
+                               t.Errorf("Expected 2, got %v", values[1])
+                       }
+               }
+       }
+       if ct != 3 {
+               t.Errorf("expected 3 rows, got %d", ct)
+       }
+
+       if err := rows.Close(); err != nil {
+               t.Errorf("error closing rows: %s", err)
+       }
+}
+
 func TestQueryRow(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)