From: Brad Fitzpatrick Date: Tue, 11 Oct 2016 15:34:58 +0000 (-0700) Subject: database/sql: accept nil pointers to Valuers implemented on value receivers X-Git-Tag: go1.8beta1~842 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=0ce1d79a6a771f7449ec493b993ed2a720917870;p=gostls13.git database/sql: accept nil pointers to Valuers implemented on value receivers The driver.Valuer interface lets types map their Go representation to a suitable database/sql/driver.Value. If a user defines the Value method with a value receiver, such as: type MyStr string func (s MyStr) Value() (driver.Value, error) { return strings.ToUpper(string(s)), nil } Then they can't use (*MyStr)(nil) as an argument to an SQL call via database/sql, because *MyStr also implements driver.Value, but via a compiler-generated wrapper which checks whether the pointer is nil and panics if so. We now accept (*MyStr)(nil) and map it to "nil" (an SQL "NULL") if the Valuer method is implemented on MyStr instead of *MyStr. If a user implements the driver.Value interface with a pointer receiver, they retain full control of what nil means: type MyStr string func (s *MyStr) Value() (driver.Value, error) { if s == nil { return "missing MyStr", nil } return strings.ToUpper(string(*s)), nil } Adds tests for both cases. Fixes #8415 Change-Id: I897d609d80d46e2354d2669a8a3e090688eee3ad Reviewed-on: https://go-review.googlesource.com/31259 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Daniel Theophanes Reviewed-by: Brad Fitzpatrick --- diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go index cee96319da..594a816966 100644 --- a/src/database/sql/convert.go +++ b/src/database/sql/convert.go @@ -57,8 +57,8 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) // First, see if the value itself knows how to convert // itself to a driver type. For example, a NullString // struct changing into a string or nil. - if svi, ok := arg.(driver.Valuer); ok { - sv, err := svi.Value() + if vr, ok := arg.(driver.Valuer); ok { + sv, err := callValuerValue(vr) if err != nil { return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err) } @@ -341,3 +341,25 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { } return } + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// Issue 8415. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This function is mirrored in the database/sql/driver package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} diff --git a/src/database/sql/convert_test.go b/src/database/sql/convert_test.go index ab81f2f65a..4dfab1f6be 100644 --- a/src/database/sql/convert_test.go +++ b/src/database/sql/convert_test.go @@ -9,6 +9,7 @@ import ( "fmt" "reflect" "runtime" + "strings" "testing" "time" ) @@ -389,3 +390,85 @@ func TestUserDefinedBytes(t *testing.T) { t.Fatal("userDefinedBytes got potentially dirty driver memory") } } + +type Valuer_V string + +func (v Valuer_V) Value() (driver.Value, error) { + return strings.ToUpper(string(v)), nil +} + +type Valuer_P string + +func (p *Valuer_P) Value() (driver.Value, error) { + if p == nil { + return "nil-to-str", nil + } + return strings.ToUpper(string(*p)), nil +} + +func TestDriverArgs(t *testing.T) { + var nilValuerVPtr *Valuer_V + var nilValuerPPtr *Valuer_P + var nilStrPtr *string + tests := []struct { + args []interface{} + want []driver.NamedValue + }{ + 0: { + args: []interface{}{Valuer_V("foo")}, + want: []driver.NamedValue{ + driver.NamedValue{ + Ordinal: 1, + Value: "FOO", + }, + }, + }, + 1: { + args: []interface{}{nilValuerVPtr}, + want: []driver.NamedValue{ + driver.NamedValue{ + Ordinal: 1, + Value: nil, + }, + }, + }, + 2: { + args: []interface{}{nilValuerPPtr}, + want: []driver.NamedValue{ + driver.NamedValue{ + Ordinal: 1, + Value: "nil-to-str", + }, + }, + }, + 3: { + args: []interface{}{"plain-str"}, + want: []driver.NamedValue{ + driver.NamedValue{ + Ordinal: 1, + Value: "plain-str", + }, + }, + }, + 4: { + args: []interface{}{nilStrPtr}, + want: []driver.NamedValue{ + driver.NamedValue{ + Ordinal: 1, + Value: nil, + }, + }, + }, + } + for i, tt := range tests { + ds := new(driverStmt) + got, err := driverArgs(ds, tt.args) + if err != nil { + t.Errorf("test[%d]: %v", i, err) + continue + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("test[%d]: got %v, want %v", i, got, tt.want) + } + } +} diff --git a/src/database/sql/driver/types.go b/src/database/sql/driver/types.go index e480e701a4..c93c97a392 100644 --- a/src/database/sql/driver/types.go +++ b/src/database/sql/driver/types.go @@ -208,13 +208,35 @@ type defaultConverter struct{} var _ ValueConverter = defaultConverter{} +var valuerReflectType = reflect.TypeOf((*Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// Issue 8415. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This function is mirrored in the database/sql package. +func callValuerValue(vr Valuer) (v Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} + func (defaultConverter) ConvertValue(v interface{}) (Value, error) { if IsValue(v) { return v, nil } - if svi, ok := v.(Valuer); ok { - sv, err := svi.Value() + if vr, ok := v.(Valuer); ok { + sv, err := callValuerValue(vr) if err != nil { return nil, err }