]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: ensure Null* types have Valid=false when Scan returns error
authorRavi Sastry Kadali <ravisastryk@gmail.com>
Mon, 26 Jan 2026 02:29:30 +0000 (18:29 -0800)
committerGopher Robot <gobot@golang.org>
Fri, 6 Feb 2026 20:10:59 +0000 (12:10 -0800)
The Scan methods for NullString, NullInt64, NullInt32, NullFloat64,
NullBool, and NullTime set Valid=true before calling convertAssign.
If convertAssign returns an error, Valid remains true, which creates
an inconsistent state where Valid=true but err!=nil.

Fix by setting Valid only after successful conversion.

Fixes #45662

Change-Id: I855a20abbe517ed017f7c9b8f5603b17bd9d487d
Reviewed-on: https://go-review.googlesource.com/c/go/+/739160
Auto-Submit: Sean Liao <sean@liao.dev>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
Reviewed-by: Michael Pratt <mpratt@google.com>
src/database/sql/sql.go
src/database/sql/sql_test.go

index 4be450ca87687728797bafdb10e976adc5a71f58..c8ec91c1ec851623216c5543de165281ab5d4d82 100644 (file)
@@ -202,8 +202,9 @@ func (ns *NullString) Scan(value any) error {
                ns.String, ns.Valid = "", false
                return nil
        }
-       ns.Valid = true
-       return convertAssign(&ns.String, value)
+       err := convertAssign(&ns.String, value)
+       ns.Valid = err == nil
+       return err
 }
 
 // Value implements the [driver.Valuer] interface.
@@ -228,8 +229,9 @@ func (n *NullInt64) Scan(value any) error {
                n.Int64, n.Valid = 0, false
                return nil
        }
-       n.Valid = true
-       return convertAssign(&n.Int64, value)
+       err := convertAssign(&n.Int64, value)
+       n.Valid = err == nil
+       return err
 }
 
 // Value implements the [driver.Valuer] interface.
@@ -254,8 +256,9 @@ func (n *NullInt32) Scan(value any) error {
                n.Int32, n.Valid = 0, false
                return nil
        }
-       n.Valid = true
-       return convertAssign(&n.Int32, value)
+       err := convertAssign(&n.Int32, value)
+       n.Valid = err == nil
+       return err
 }
 
 // Value implements the [driver.Valuer] interface.
@@ -334,8 +337,9 @@ func (n *NullFloat64) Scan(value any) error {
                n.Float64, n.Valid = 0, false
                return nil
        }
-       n.Valid = true
-       return convertAssign(&n.Float64, value)
+       err := convertAssign(&n.Float64, value)
+       n.Valid = err == nil
+       return err
 }
 
 // Value implements the [driver.Valuer] interface.
@@ -360,8 +364,9 @@ func (n *NullBool) Scan(value any) error {
                n.Bool, n.Valid = false, false
                return nil
        }
-       n.Valid = true
-       return convertAssign(&n.Bool, value)
+       err := convertAssign(&n.Bool, value)
+       n.Valid = err == nil
+       return err
 }
 
 // Value implements the [driver.Valuer] interface.
@@ -386,8 +391,9 @@ func (n *NullTime) Scan(value any) error {
                n.Time, n.Valid = time.Time{}, false
                return nil
        }
-       n.Valid = true
-       return convertAssign(&n.Time, value)
+       err := convertAssign(&n.Time, value)
+       n.Valid = err == nil
+       return err
 }
 
 // Value implements the [driver.Valuer] interface.
@@ -422,8 +428,9 @@ func (n *Null[T]) Scan(value any) error {
                n.V, n.Valid = *new(T), false
                return nil
        }
-       n.Valid = true
-       return convertAssign(&n.V, value)
+       err := convertAssign(&n.V, value)
+       n.Valid = err == nil
+       return err
 }
 
 func (n Null[T]) Value() (driver.Value, error) {
index e8a65600973ba72b3960bb2a193999f414f3f2fa..5f093a2d6de077cac24171a19304085931c2acf4 100644 (file)
@@ -5086,3 +5086,152 @@ type unknownInputsValueConverter struct{}
 func (unknownInputsValueConverter) ConvertValue(v any) (driver.Value, error) {
        return "string", nil
 }
+
+func TestNullTypeScanErrorConsistency(t *testing.T) {
+       // Issue #45662: Null* types should have Valid=false when Scan returns an error.
+       // Previously, Valid was set to true before convertAssign was called,
+       // so if conversion failed, Valid would still be true despite the error.
+
+       tests := []struct {
+               name    string
+               scanner Scanner
+               input   any
+               wantErr bool
+       }{
+               {
+                       name:    "NullInt32 with invalid input",
+                       scanner: &NullInt32{},
+                       input:   []byte("not_a_number"),
+                       wantErr: true,
+               },
+               {
+                       name:    "NullInt64 with invalid input",
+                       scanner: &NullInt64{},
+                       input:   []byte("not_a_number"),
+                       wantErr: true,
+               },
+               {
+                       name:    "NullFloat64 with invalid input",
+                       scanner: &NullFloat64{},
+                       input:   []byte("not_a_float"),
+                       wantErr: true,
+               },
+               {
+                       name:    "NullBool with invalid input",
+                       scanner: &NullBool{},
+                       input:   []byte("not_a_bool"),
+                       wantErr: true,
+               },
+               // Valid cases should still work
+               {
+                       name:    "NullInt32 with valid input",
+                       scanner: &NullInt32{},
+                       input:   int64(42),
+                       wantErr: false,
+               },
+               {
+                       name:    "NullInt64 with valid input",
+                       scanner: &NullInt64{},
+                       input:   int64(42),
+                       wantErr: false,
+               },
+               {
+                       name:    "NullFloat64 with valid input",
+                       scanner: &NullFloat64{},
+                       input:   float64(3.14),
+                       wantErr: false,
+               },
+               {
+                       name:    "NullBool with valid input",
+                       scanner: &NullBool{},
+                       input:   true,
+                       wantErr: false,
+               },
+               {
+                       name:    "NullString with valid input",
+                       scanner: &NullString{},
+                       input:   "hello",
+                       wantErr: false,
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       err := tt.scanner.Scan(tt.input)
+
+                       // Check that error matches expectation
+                       if (err != nil) != tt.wantErr {
+                               t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr)
+                       }
+
+                       // The key invariant: Valid should be the opposite of whether we got an error
+                       // (assuming non-nil input)
+                       var valid bool
+                       switch s := tt.scanner.(type) {
+                       case *NullInt32:
+                               valid = s.Valid
+                       case *NullInt64:
+                               valid = s.Valid
+                       case *NullFloat64:
+                               valid = s.Valid
+                       case *NullBool:
+                               valid = s.Valid
+                       case *NullString:
+                               valid = s.Valid
+                       case *NullTime:
+                               valid = s.Valid
+                       }
+
+                       if err != nil && valid {
+                               t.Errorf("Scan() returned error but Valid=true; want Valid=false when err!=nil")
+                       }
+                       if err == nil && !valid {
+                               t.Errorf("Scan() returned nil error but Valid=false; want Valid=true when err==nil")
+                       }
+               })
+       }
+}
+
+// TestNullTypeScanNil verifies that scanning nil sets Valid=false without error.
+func TestNullTypeScanNil(t *testing.T) {
+       tests := []struct {
+               name    string
+               scanner Scanner
+       }{
+               {"NullString", &NullString{String: "preset", Valid: true}},
+               {"NullInt64", &NullInt64{Int64: 42, Valid: true}},
+               {"NullInt32", &NullInt32{Int32: 42, Valid: true}},
+               {"NullFloat64", &NullFloat64{Float64: 3.14, Valid: true}},
+               {"NullBool", &NullBool{Bool: true, Valid: true}},
+               {"NullTime", &NullTime{Time: time.Now(), Valid: true}},
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       err := tt.scanner.Scan(nil)
+                       if err != nil {
+                               t.Errorf("Scan(nil) error = %v; want nil", err)
+                       }
+
+                       var valid bool
+                       switch s := tt.scanner.(type) {
+                       case *NullString:
+                               valid = s.Valid
+                       case *NullInt64:
+                               valid = s.Valid
+                       case *NullInt32:
+                               valid = s.Valid
+                       case *NullFloat64:
+                               valid = s.Valid
+                       case *NullBool:
+                               valid = s.Valid
+                       case *NullTime:
+                               valid = s.Valid
+                       }
+
+                       if valid {
+                               t.Errorf("Scan(nil) left Valid=true; want Valid=false")
+                       }
+               })
+       }
+}