]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: use slices to simplify the code
authorapocelipes <seve3r@outlook.com>
Fri, 29 Mar 2024 20:09:37 +0000 (20:09 +0000)
committerGopher Robot <gobot@golang.org>
Mon, 1 Apr 2024 12:38:07 +0000 (12:38 +0000)
Change-Id: Ia198272330626271ee7d4e1ae46afca819ab2933
GitHub-Last-Rev: e713ac31638671f60cc3cf62fa514f784e834e66
GitHub-Pull-Request: golang/go#66572
Reviewed-on: https://go-review.googlesource.com/c/go/+/574995
Reviewed-by: Emmanuel Odeke <emmanuel@orijtech.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Than McIntosh <thanm@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: qiulaidongfeng <2645477756@qq.com>
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index c6c3172b5c34151b635521ceee0d350a9a5df9c1..3dfcd447b52bca1234c2407a7ce62c017d46015f 100644 (file)
@@ -11,7 +11,7 @@ import (
        "fmt"
        "io"
        "reflect"
-       "sort"
+       "slices"
        "strconv"
        "strings"
        "sync"
@@ -120,12 +120,7 @@ type table struct {
 }
 
 func (t *table) columnIndex(name string) int {
-       for n, nname := range t.colname {
-               if name == nname {
-                       return n
-               }
-       }
-       return -1
+       return slices.Index(t.colname, name)
 }
 
 type row struct {
@@ -217,15 +212,6 @@ func init() {
        Register("test", fdriver)
 }
 
-func contains(list []string, y string) bool {
-       for _, x := range list {
-               if x == y {
-                       return true
-               }
-       }
-       return false
-}
-
 type Dummy struct {
        driver.Driver
 }
@@ -235,7 +221,7 @@ func TestDrivers(t *testing.T) {
        Register("test", fdriver)
        Register("invalid", Dummy{})
        all := Drivers()
-       if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
+       if len(all) < 2 || !slices.IsSorted(all) || !slices.Contains(all, "test") || !slices.Contains(all, "invalid") {
                t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
        }
 }
@@ -345,10 +331,8 @@ func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
        if !ok {
                return
        }
-       for n, cname := range t.colname {
-               if cname == column {
-                       return t.coltype[n], true
-               }
+       if i := slices.Index(t.colname, column); i != -1 {
+               return t.coltype[i], true
        }
        return "", false
 }
@@ -823,6 +807,15 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
        return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
 }
 
+func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value {
+       for i := range args {
+               if args[i].Name == name {
+                       return args[i].Value
+               }
+       }
+       return nil
+}
+
 // When doInsert is true, add the row to the table.
 // When doInsert is false do prep-work and error checking, but don't
 // actually add the row to the table.
@@ -857,11 +850,8 @@ func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.R
                                val = args[argPos].Value
                        } else {
                                // Assign value from argument placeholder name.
-                               for _, a := range args {
-                                       if a.Name == strvalue[1:] {
-                                               val = a.Value
-                                               break
-                                       }
+                               if v := valueFromPlaceholderName(args, strvalue[1:]); v != nil {
+                                       val = v
                                }
                        }
                        argPos++
@@ -997,12 +987,8 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
                                if wcol.Placeholder == "?" {
                                        argValue = args[wcol.Ordinal-1].Value
                                } else {
-                                       // Assign arg value from placeholder name.
-                                       for _, a := range args {
-                                               if a.Name == wcol.Placeholder[1:] {
-                                                       argValue = a.Value
-                                                       break
-                                               }
+                                       if v := valueFromPlaceholderName(args, wcol.Placeholder[1:]); v != nil {
+                                               argValue = v
                                        }
                                }
                                if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
index 36995a1059bd6db429a04cc7e19c03855647ab80..5b4a3f540935863ccbd7611ca103e8d288d80ef1 100644 (file)
@@ -24,7 +24,7 @@ import (
        "math/rand/v2"
        "reflect"
        "runtime"
-       "sort"
+       "slices"
        "strconv"
        "sync"
        "sync/atomic"
@@ -69,7 +69,7 @@ func Drivers() []string {
        for name := range drivers {
                list = append(list, name)
        }
-       sort.Strings(list)
+       slices.Sort(list)
        return list
 }
 
@@ -3452,10 +3452,8 @@ func (r *Row) Scan(dest ...any) error {
        // they were obtained from the network anyway) But for now we
        // don't care.
        defer r.rows.Close()
-       for _, dp := range dest {
-               if _, ok := dp.(*RawBytes); ok {
-                       return errors.New("sql: RawBytes isn't allowed on Row.Scan")
-               }
+       if scanArgsContainRawBytes(dest) {
+               return errors.New("sql: RawBytes isn't allowed on Row.Scan")
        }
 
        if !r.rows.Next() {
index 7bf3ebbe08377fcb2111dc6a9325f26597a4ad23..25ca5ff0adaed1c4a1bcbfc61624330174d0ae25 100644 (file)
@@ -40,14 +40,7 @@ func init() {
                freedFrom[c] = s
        }
        putConnHook = func(db *DB, c *driverConn) {
-               idx := -1
-               for i, v := range db.freeConn {
-                       if v == c {
-                               idx = i
-                               break
-                       }
-               }
-               if idx >= 0 {
+               if slices.Contains(db.freeConn, c) {
                        // print before panic, as panic may get lost due to conflicting panic
                        // (all goroutines asleep) elsewhere, since we might not unlock
                        // the mutex in freeConn here.
@@ -291,7 +284,7 @@ func TestQuery(t *testing.T) {
                {age: 2, name: "Bob"},
                {age: 3, name: "Chris"},
        }
-       if !reflect.DeepEqual(got, want) {
+       if !slices.Equal(got, want) {
                t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
        }
 
@@ -355,7 +348,7 @@ func TestQueryContext(t *testing.T) {
                {age: 1, name: "Alice"},
                {age: 2, name: "Bob"},
        }
-       if !reflect.DeepEqual(got, want) {
+       if !slices.Equal(got, want) {
                t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
        }
 
@@ -540,7 +533,7 @@ func TestMultiResultSetQuery(t *testing.T) {
                {age: 2, name: "Bob"},
                {age: 3, name: "Chris"},
        }
-       if !reflect.DeepEqual(got1, want1) {
+       if !slices.Equal(got1, want1) {
                t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
        }
 
@@ -566,7 +559,7 @@ func TestMultiResultSetQuery(t *testing.T) {
                {name: "Bob"},
                {name: "Chris"},
        }
-       if !reflect.DeepEqual(got2, want2) {
+       if !slices.Equal(got2, want2) {
                t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
        }
        if rows.NextResultSet() {
@@ -614,7 +607,7 @@ func TestQueryNamedArg(t *testing.T) {
        want := []row{
                {age: 2, name: "Bob"},
        }
-       if !reflect.DeepEqual(got, want) {
+       if !slices.Equal(got, want) {
                t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
        }
 
@@ -724,7 +717,7 @@ func TestRowsColumns(t *testing.T) {
                t.Fatalf("Columns: %v", err)
        }
        want := []string{"age", "name"}
-       if !reflect.DeepEqual(cols, want) {
+       if !slices.Equal(cols, want) {
                t.Errorf("got %#v; want %#v", cols, want)
        }
        if err := rows.Close(); err != nil {
@@ -827,7 +820,7 @@ func TestQueryRow(t *testing.T) {
                t.Fatalf("photo QueryRow+Scan: %v", err)
        }
        want := []byte("APHOTO")
-       if !reflect.DeepEqual(photo, want) {
+       if !slices.Equal(photo, want) {
                t.Errorf("photo = %q; want %q", photo, want)
        }
 }