]> Cypherpunks repositories - gostls13.git/commitdiff
text/template: permit eq and ne funcs to check against nil
authorRob Pike <r@golang.org>
Mon, 14 Mar 2022 02:21:06 +0000 (13:21 +1100)
committerIan Lance Taylor <iant@golang.org>
Mon, 4 Apr 2022 17:28:30 +0000 (17:28 +0000)
The existing code errors out immediately if the argument is not
"comparable", making it impossible to test a slice, map, and so
on from being compared to nil.

Fix by delaying the "comparable" error check until we encounter
an actual check between two non-comparable, non-nil values.

Note for the future: reflect makes it unnecessarily clumsy
to deal with nil values in cases like this. For instance, it
should be possible to check if a value is nil without stepping
around a panic. See the new functions isNil and canCompare
for my (too expensive) workaround.

Fixes #51642

Change-Id: Ic4072698c4910130ea7e3d76e7a148d8a8b88162
Reviewed-on: https://go-review.googlesource.com/c/go/+/392274
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Trust: Cherry Mui <cherryyz@google.com>

src/html/template/exec_test.go
src/text/template/exec_test.go
src/text/template/funcs.go

index 6cf710efabc4f49eba8152c4cff41993b42a2dd0..f042cf5125f0f43735f0ff8044f14b5399c13468 100644 (file)
@@ -1191,15 +1191,19 @@ var cmpTests = []cmpTest{
        {"eq .Iface1 .Iface1", "true", true},
        {"eq .Iface1 .Iface2", "false", true},
        {"eq .Iface2 .Iface2", "true", true},
+       {"eq .Map .Map", "true", true},        // Uncomparable types but nil is OK.
+       {"eq .Map nil", "true", true},         // Uncomparable types but nil is OK.
+       {"eq nil .Map", "true", true},         // Uncomparable types but nil is OK.
+       {"eq .Map .NonNilMap", "false", true}, // Uncomparable types but nil is OK.
        // Errors
-       {"eq `xy` 1", "", false},       // Different types.
-       {"eq 2 2.0", "", false},        // Different types.
-       {"lt true true", "", false},    // Unordered types.
-       {"lt 1+0i 1+0i", "", false},    // Unordered types.
-       {"eq .Ptr 1", "", false},       // Incompatible types.
-       {"eq .Ptr .NegOne", "", false}, // Incompatible types.
-       {"eq .Map .Map", "", false},    // Uncomparable types.
-       {"eq .Map .V1", "", false},     // Uncomparable types.
+       {"eq `xy` 1", "", false},                // Different types.
+       {"eq 2 2.0", "", false},                 // Different types.
+       {"lt true true", "", false},             // Unordered types.
+       {"lt 1+0i 1+0i", "", false},             // Unordered types.
+       {"eq .Ptr 1", "", false},                // Incompatible types.
+       {"eq .Ptr .NegOne", "", false},          // Incompatible types.
+       {"eq .Map .V1", "", false},              // Uncomparable types.
+       {"eq .NonNilMap .NonNilMap", "", false}, // Uncomparable types.
 }
 
 func TestComparison(t *testing.T) {
@@ -1208,16 +1212,18 @@ func TestComparison(t *testing.T) {
                Uthree, Ufour  uint
                NegOne, Three  int
                Ptr, NilPtr    *int
+               NonNilMap      map[int]int
                Map            map[int]int
                V1, V2         V
                Iface1, Iface2 fmt.Stringer
        }{
-               Uthree: 3,
-               Ufour:  4,
-               NegOne: -1,
-               Three:  3,
-               Ptr:    new(int),
-               Iface1: b,
+               Uthree:    3,
+               Ufour:     4,
+               NegOne:    -1,
+               Three:     3,
+               Ptr:       new(int),
+               NonNilMap: make(map[int]int),
+               Iface1:    b,
        }
        for _, test := range cmpTests {
                text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr)
index 8c8143396dbefe37c067fa3b9f0d796b0bedef02..56566b920feae2e1d2a7cd7dc0c3511e35614ec7 100644 (file)
@@ -1220,15 +1220,19 @@ var cmpTests = []cmpTest{
        {"eq .NilIface .Iface1", "false", true},
        {"eq .NilIface 0", "false", true},
        {"eq 0 .NilIface", "false", true},
+       {"eq .Map .Map", "true", true},        // Uncomparable types but nil is OK.
+       {"eq .Map nil", "true", true},         // Uncomparable types but nil is OK.
+       {"eq nil .Map", "true", true},         // Uncomparable types but nil is OK.
+       {"eq .Map .NonNilMap", "false", true}, // Uncomparable types but nil is OK.
        // Errors
-       {"eq `xy` 1", "", false},       // Different types.
-       {"eq 2 2.0", "", false},        // Different types.
-       {"lt true true", "", false},    // Unordered types.
-       {"lt 1+0i 1+0i", "", false},    // Unordered types.
-       {"eq .Ptr 1", "", false},       // Incompatible types.
-       {"eq .Ptr .NegOne", "", false}, // Incompatible types.
-       {"eq .Map .Map", "", false},    // Uncomparable types.
-       {"eq .Map .V1", "", false},     // Uncomparable types.
+       {"eq `xy` 1", "", false},                // Different types.
+       {"eq 2 2.0", "", false},                 // Different types.
+       {"lt true true", "", false},             // Unordered types.
+       {"lt 1+0i 1+0i", "", false},             // Unordered types.
+       {"eq .Ptr 1", "", false},                // Incompatible types.
+       {"eq .Ptr .NegOne", "", false},          // Incompatible types.
+       {"eq .Map .V1", "", false},              // Uncomparable types.
+       {"eq .NonNilMap .NonNilMap", "", false}, // Uncomparable types.
 }
 
 func TestComparison(t *testing.T) {
@@ -1237,16 +1241,18 @@ func TestComparison(t *testing.T) {
                Uthree, Ufour    uint
                NegOne, Three    int
                Ptr, NilPtr      *int
+               NonNilMap        map[int]int
                Map              map[int]int
                V1, V2           V
                Iface1, NilIface fmt.Stringer
        }{
-               Uthree: 3,
-               Ufour:  4,
-               NegOne: -1,
-               Three:  3,
-               Ptr:    new(int),
-               Iface1: b,
+               Uthree:    3,
+               Ufour:     4,
+               NegOne:    -1,
+               Three:     3,
+               Ptr:       new(int),
+               NonNilMap: make(map[int]int),
+               Iface1:    b,
        }
        for _, test := range cmpTests {
                text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr)
index dca5ed28db84ff8f03a927d91c8980deb6ff63dd..1f63b361f88bb045da25555e1e02521955a5b667 100644 (file)
@@ -436,14 +436,33 @@ func basicKind(v reflect.Value) (kind, error) {
        return invalidKind, errBadComparisonType
 }
 
+// isNil returns true if v is the zero reflect.Value, or nil of its type.
+func isNil(v reflect.Value) bool {
+       if v == zero {
+               return true
+       }
+       switch v.Kind() {
+       case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
+               return v.IsNil()
+       }
+       return false
+}
+
+// canCompare reports whether v1 and v2 are both the same kind, or one is nil.
+// Called only when dealing with nillable types, or there's about to be an error.
+func canCompare(v1, v2 reflect.Value) bool {
+       k1 := v1.Kind()
+       k2 := v2.Kind()
+       if k1 == k2 {
+               return true
+       }
+       // We know the type can be compared to nil.
+       return k1 == reflect.Invalid || k2 == reflect.Invalid
+}
+
 // eq evaluates the comparison a == b || a == c || ...
 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
        arg1 = indirectInterface(arg1)
-       if arg1 != zero {
-               if t1 := arg1.Type(); !t1.Comparable() {
-                       return false, fmt.Errorf("uncomparable type %s: %v", t1, arg1)
-               }
-       }
        if len(arg2) == 0 {
                return false, errNoComparison
        }
@@ -479,11 +498,14 @@ func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
                        case uintKind:
                                truth = arg1.Uint() == arg.Uint()
                        default:
-                               if arg == zero || arg1 == zero {
-                                       truth = arg1 == arg
+                               if !canCompare(arg1, arg) {
+                                       return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
+                               }
+                               if isNil(arg1) || isNil(arg) {
+                                       truth = isNil(arg) == isNil(arg1)
                                } else {
-                                       if t2 := arg.Type(); !t2.Comparable() {
-                                               return false, fmt.Errorf("uncomparable type %s: %v", t2, arg)
+                                       if !arg.Type().Comparable() {
+                                               return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
                                        }
                                        truth = arg1.Interface() == arg.Interface()
                                }