From: Daniel Theophanes Date: Tue, 27 Sep 2016 20:27:02 +0000 (-0700) Subject: database/sql: support returning query database types X-Git-Tag: go1.8beta1~805 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=2a85578b0ecd424e95b29d810b7a414a299fd6a7;p=gostls13.git database/sql: support returning query database types Creates a ColumnType structure that can be extended in to future. Allow drivers to implement what makes sense for the database. Fixes #16652 Change-Id: Ieb1fd64eac1460107b1d3474eba5201fa300a4ec Reviewed-on: https://go-review.googlesource.com/29961 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index 6cc970f688..ad988cc785 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -11,6 +11,7 @@ package driver import ( "context" "errors" + "reflect" ) // Value is a value that drivers must be able to handle. @@ -239,6 +240,60 @@ type RowsNextResultSet interface { NextResultSet() error } +// RowsColumnTypeScanType may be implemented by Rows. It should return +// the value type that can be used to scan types into. For example, the database +// column type "bigint" this should return "reflect.TypeOf(int64(0))". +type RowsColumnTypeScanType interface { + Rows + ColumnTypeScanType(index int) reflect.Type +} + +// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the +// database system type name without the length. Type names should be uppercase. +// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", +// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", +// "TIMESTAMP". +type RowsColumnTypeDatabaseTypeName interface { + Rows + ColumnTypeDatabaseTypeName(index int) string +} + +// RowsColumnTypeLength may be implemented by Rows. It should return the length +// of the column type if the column is a variable length type. If the column is +// not a variable length type ok should return false. +// If length is not limited other than system limits, it should return math.MaxInt64. +// The following are examples of returned values for various types: +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) +type RowsColumnTypeLength interface { + Rows + ColumnTypeLength(index int) (length int64, ok bool) +} + +// RowsColumnTypeNullable may be implemented by Rows. The nullable value should +// be true if it is known the column may be null, or false if the column is known +// to be not nullable. +// If the column nullability is unknown, ok should be false. +type RowsColumnTypeNullable interface { + Rows + ColumnTypeNullable(index int) (nullable, ok bool) +} + +// RowsColumnTypePrecisionScale may be implemented by Rows. It should return +// the precision and scale for decimal types. If not applicable, ok should be false. +// The following are examples of returned values for various types: +// decimal(38, 4) (38, 4, true) +// int (0, 0, false) +// decimal (math.MaxInt64, math.MaxInt64, true) +type RowsColumnTypePrecisionScale interface { + Rows + ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) +} + // Tx is a transaction. type Tx interface { Commit() error diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 07f50196a5..c42f23208f 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log" + "reflect" "sort" "strconv" "strings" @@ -405,6 +406,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) } stmt.table = parts[0] + stmt.colName = strings.Split(parts[1], ",") for n, colspec := range strings.Split(parts[2], ",") { if colspec == "" { @@ -725,6 +727,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( setMRows := make([][]*row, 0, 1) setColumns := make([][]string, 0, 1) + setColType := make([][]string, 0, 1) for { db.mu.Lock() @@ -794,10 +797,16 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( mrows = append(mrows, mrow) } + var colType []string + for _, column := range s.colName { + colType = append(colType, t.coltype[t.columnIndex(column)]) + } + t.mu.Unlock() setMRows = append(setMRows, mrows) setColumns = append(setColumns, s.colName) + setColType = append(setColType, colType) if s.next == nil { break @@ -806,10 +815,11 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( } cursor := &rowsCursor{ - posRow: -1, - rows: setMRows, - cols: setColumns, - errPos: -1, + posRow: -1, + rows: setMRows, + cols: setColumns, + colType: setColType, + errPos: -1, } return cursor, nil } @@ -844,11 +854,12 @@ func (tx *fakeTx) Rollback() error { } type rowsCursor struct { - cols [][]string - posSet int - posRow int - rows [][]*row - closed bool + cols [][]string + colType [][]string + posSet int + posRow int + rows [][]*row + closed bool // errPos and err are for making Next return early with error. errPos int @@ -874,6 +885,10 @@ func (rc *rowsCursor) Columns() []string { return rc.cols[rc.posSet] } +func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { + return colTypeToReflectType(rc.colType[rc.posSet][index]) +} + var rowsCursorNextHook func(dest []driver.Value) error func (rc *rowsCursor) Next(dest []driver.Value) error { @@ -980,3 +995,29 @@ func converterForType(typ string) driver.ValueConverter { } panic("invalid fakedb column type of " + typ) } + +func colTypeToReflectType(typ string) reflect.Type { + switch typ { + case "bool": + return reflect.TypeOf(false) + case "nullbool": + return reflect.TypeOf(NullBool{}) + case "int32": + return reflect.TypeOf(int32(0)) + case "string": + return reflect.TypeOf("") + case "nullstring": + return reflect.TypeOf(NullString{}) + case "int64": + return reflect.TypeOf(int64(0)) + case "nullint64": + return reflect.TypeOf(NullInt64{}) + case "float64": + return reflect.TypeOf(float64(0)) + case "nullfloat64": + return reflect.TypeOf(NullFloat64{}) + case "datetime": + return reflect.TypeOf(time.Time{}) + } + panic("invalid fakedb column type of " + typ) +} diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 616acb2be1..defe960742 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "io" + "reflect" "runtime" "sort" "sync" @@ -996,8 +997,8 @@ const maxBadConnRetries = 2 // The caller must call the statement's Close method // when the statement is no longer needed. // -// The provided context is for the preparation of the statment, not for the execution of -// the statement. +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var stmt *Stmt var err error @@ -2033,6 +2034,107 @@ func (rs *Rows) Columns() ([]string, error) { return rs.rowsi.Columns(), nil } +// ColumnTypes returns column information such as column type, length, +// and nullable. Some information may not be available from some drivers. +func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { + if rs.isClosed() { + return nil, errors.New("sql: Rows are closed") + } + if rs.rowsi == nil { + return nil, errors.New("sql: no Rows available") + } + return rowsColumnInfoSetup(rs.rowsi), nil +} + +// ColumnType contains the name and type of a column. +type ColumnType struct { + name string + + hasNullable bool + hasLength bool + hasPrecisionScale bool + + nullable bool + length int64 + databaseType string + precision int64 + scale int64 + scanType reflect.Type +} + +// Name returns the name or alias of the column. +func (ci *ColumnType) Name() string { + return ci.name +} + +// Length returns the column type length for variable length column types such +// as text and binary field types. If the type length is unbounded the value will +// be math.MaxInt64 (any database limits will still apply). +// If the column type is not variable length, such as an int, or if not supported +// by the driver ok is false. +func (ci *ColumnType) Length() (length int64, ok bool) { + return ci.length, ci.hasLength +} + +// DecimalSize returns the scale and precision of a decimal type. +// If not applicable or if not supported ok is false. +func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) { + return ci.precision, ci.scale, ci.hasPrecisionScale +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +// If a driver does not support this property ScanType will return +// the type of an empty interface. +func (ci *ColumnType) ScanType() reflect.Type { + return ci.scanType +} + +// Nullable returns whether the column may be null. +// If a driver does not support this property ok will be false. +func (ci *ColumnType) Nullable() (nullable, ok bool) { + return ci.nullable, ci.hasNullable +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", "INT", "BIGINT". +func (ci *ColumnType) DatabaseTypeName() string { + return ci.databaseType +} + +func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType { + names := rowsi.Columns() + + list := make([]*ColumnType, len(names)) + for i := range list { + ci := &ColumnType{ + name: names[i], + } + list[i] = ci + + if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok { + ci.scanType = prop.ColumnTypeScanType(i) + } else { + ci.scanType = reflect.TypeOf(new(interface{})).Elem() + } + if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok { + ci.databaseType = prop.ColumnTypeDatabaseTypeName(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok { + ci.length, ci.hasLength = prop.ColumnTypeLength(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok { + ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok { + ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i) + } + } + return list +} + // Scan copies the columns in the current row into the values pointed // at by dest. The number of values in dest must be the same as the // number of columns in Rows. diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 885cadf3c6..f4b887ca96 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -499,6 +499,56 @@ func TestRowsColumns(t *testing.T) { } } +func TestRowsColumnTypes(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + types[i] = st + } + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + ct := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + ct++ + if ct == 0 { + if values[0].(string) != "Bob" { + t.Errorf("Expected Bob, got %v", values[0]) + } + if values[1].(int) != 2 { + t.Errorf("Expected 2, got %v", values[1]) + } + } + } + if ct != 3 { + t.Errorf("expected 3 rows, got %d", ct) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } +} + func TestQueryRow(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)