From 9263a71b92e84aa34d0e35887d5c23f5a0a21537 Mon Sep 17 00:00:00 2001 From: Ravi Sastry Kadali Date: Sun, 25 Jan 2026 18:29:30 -0800 Subject: [PATCH] database/sql: ensure Null* types have Valid=false when Scan returns error The Scan methods for NullString, NullInt64, NullInt32, NullFloat64, NullBool, and NullTime set Valid=true before calling convertAssign. If convertAssign returns an error, Valid remains true, which creates an inconsistent state where Valid=true but err!=nil. Fix by setting Valid only after successful conversion. Fixes #45662 Change-Id: I855a20abbe517ed017f7c9b8f5603b17bd9d487d Reviewed-on: https://go-review.googlesource.com/c/go/+/739160 Auto-Submit: Sean Liao LUCI-TryBot-Result: Go LUCI Reviewed-by: Dmitri Shuralyov Reviewed-by: Sean Liao Reviewed-by: Michael Pratt --- src/database/sql/sql.go | 35 ++++---- src/database/sql/sql_test.go | 149 +++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 14 deletions(-) diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 4be450ca87..c8ec91c1ec 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -202,8 +202,9 @@ func (ns *NullString) Scan(value any) error { ns.String, ns.Valid = "", false return nil } - ns.Valid = true - return convertAssign(&ns.String, value) + err := convertAssign(&ns.String, value) + ns.Valid = err == nil + return err } // Value implements the [driver.Valuer] interface. @@ -228,8 +229,9 @@ func (n *NullInt64) Scan(value any) error { n.Int64, n.Valid = 0, false return nil } - n.Valid = true - return convertAssign(&n.Int64, value) + err := convertAssign(&n.Int64, value) + n.Valid = err == nil + return err } // Value implements the [driver.Valuer] interface. @@ -254,8 +256,9 @@ func (n *NullInt32) Scan(value any) error { n.Int32, n.Valid = 0, false return nil } - n.Valid = true - return convertAssign(&n.Int32, value) + err := convertAssign(&n.Int32, value) + n.Valid = err == nil + return err } // Value implements the [driver.Valuer] interface. @@ -334,8 +337,9 @@ func (n *NullFloat64) Scan(value any) error { n.Float64, n.Valid = 0, false return nil } - n.Valid = true - return convertAssign(&n.Float64, value) + err := convertAssign(&n.Float64, value) + n.Valid = err == nil + return err } // Value implements the [driver.Valuer] interface. @@ -360,8 +364,9 @@ func (n *NullBool) Scan(value any) error { n.Bool, n.Valid = false, false return nil } - n.Valid = true - return convertAssign(&n.Bool, value) + err := convertAssign(&n.Bool, value) + n.Valid = err == nil + return err } // Value implements the [driver.Valuer] interface. @@ -386,8 +391,9 @@ func (n *NullTime) Scan(value any) error { n.Time, n.Valid = time.Time{}, false return nil } - n.Valid = true - return convertAssign(&n.Time, value) + err := convertAssign(&n.Time, value) + n.Valid = err == nil + return err } // Value implements the [driver.Valuer] interface. @@ -422,8 +428,9 @@ func (n *Null[T]) Scan(value any) error { n.V, n.Valid = *new(T), false return nil } - n.Valid = true - return convertAssign(&n.V, value) + err := convertAssign(&n.V, value) + n.Valid = err == nil + return err } func (n Null[T]) Value() (driver.Value, error) { diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index e8a6560097..5f093a2d6d 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -5086,3 +5086,152 @@ type unknownInputsValueConverter struct{} func (unknownInputsValueConverter) ConvertValue(v any) (driver.Value, error) { return "string", nil } + +func TestNullTypeScanErrorConsistency(t *testing.T) { + // Issue #45662: Null* types should have Valid=false when Scan returns an error. + // Previously, Valid was set to true before convertAssign was called, + // so if conversion failed, Valid would still be true despite the error. + + tests := []struct { + name string + scanner Scanner + input any + wantErr bool + }{ + { + name: "NullInt32 with invalid input", + scanner: &NullInt32{}, + input: []byte("not_a_number"), + wantErr: true, + }, + { + name: "NullInt64 with invalid input", + scanner: &NullInt64{}, + input: []byte("not_a_number"), + wantErr: true, + }, + { + name: "NullFloat64 with invalid input", + scanner: &NullFloat64{}, + input: []byte("not_a_float"), + wantErr: true, + }, + { + name: "NullBool with invalid input", + scanner: &NullBool{}, + input: []byte("not_a_bool"), + wantErr: true, + }, + // Valid cases should still work + { + name: "NullInt32 with valid input", + scanner: &NullInt32{}, + input: int64(42), + wantErr: false, + }, + { + name: "NullInt64 with valid input", + scanner: &NullInt64{}, + input: int64(42), + wantErr: false, + }, + { + name: "NullFloat64 with valid input", + scanner: &NullFloat64{}, + input: float64(3.14), + wantErr: false, + }, + { + name: "NullBool with valid input", + scanner: &NullBool{}, + input: true, + wantErr: false, + }, + { + name: "NullString with valid input", + scanner: &NullString{}, + input: "hello", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.scanner.Scan(tt.input) + + // Check that error matches expectation + if (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + + // The key invariant: Valid should be the opposite of whether we got an error + // (assuming non-nil input) + var valid bool + switch s := tt.scanner.(type) { + case *NullInt32: + valid = s.Valid + case *NullInt64: + valid = s.Valid + case *NullFloat64: + valid = s.Valid + case *NullBool: + valid = s.Valid + case *NullString: + valid = s.Valid + case *NullTime: + valid = s.Valid + } + + if err != nil && valid { + t.Errorf("Scan() returned error but Valid=true; want Valid=false when err!=nil") + } + if err == nil && !valid { + t.Errorf("Scan() returned nil error but Valid=false; want Valid=true when err==nil") + } + }) + } +} + +// TestNullTypeScanNil verifies that scanning nil sets Valid=false without error. +func TestNullTypeScanNil(t *testing.T) { + tests := []struct { + name string + scanner Scanner + }{ + {"NullString", &NullString{String: "preset", Valid: true}}, + {"NullInt64", &NullInt64{Int64: 42, Valid: true}}, + {"NullInt32", &NullInt32{Int32: 42, Valid: true}}, + {"NullFloat64", &NullFloat64{Float64: 3.14, Valid: true}}, + {"NullBool", &NullBool{Bool: true, Valid: true}}, + {"NullTime", &NullTime{Time: time.Now(), Valid: true}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.scanner.Scan(nil) + if err != nil { + t.Errorf("Scan(nil) error = %v; want nil", err) + } + + var valid bool + switch s := tt.scanner.(type) { + case *NullString: + valid = s.Valid + case *NullInt64: + valid = s.Valid + case *NullInt32: + valid = s.Valid + case *NullFloat64: + valid = s.Valid + case *NullBool: + valid = s.Valid + case *NullTime: + valid = s.Valid + } + + if valid { + t.Errorf("Scan(nil) left Valid=true; want Valid=false") + } + }) + } +} -- 2.52.0