From c58f1bb65f2187d79a5842bb19f4db4cafd22794 Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Mon, 14 Mar 2022 13:21:06 +1100 Subject: [PATCH] text/template: permit eq and ne funcs to check against nil The existing code errors out immediately if the argument is not "comparable", making it impossible to test a slice, map, and so on from being compared to nil. Fix by delaying the "comparable" error check until we encounter an actual check between two non-comparable, non-nil values. Note for the future: reflect makes it unnecessarily clumsy to deal with nil values in cases like this. For instance, it should be possible to check if a value is nil without stepping around a panic. See the new functions isNil and canCompare for my (too expensive) workaround. Fixes #51642 Change-Id: Ic4072698c4910130ea7e3d76e7a148d8a8b88162 Reviewed-on: https://go-review.googlesource.com/c/go/+/392274 Reviewed-by: Ian Lance Taylor Run-TryBot: Ian Lance Taylor TryBot-Result: Gopher Robot Trust: Cherry Mui --- src/html/template/exec_test.go | 34 +++++++++++++++++------------ src/text/template/exec_test.go | 34 +++++++++++++++++------------ src/text/template/funcs.go | 40 ++++++++++++++++++++++++++-------- 3 files changed, 71 insertions(+), 37 deletions(-) diff --git a/src/html/template/exec_test.go b/src/html/template/exec_test.go index 6cf710efab..f042cf5125 100644 --- a/src/html/template/exec_test.go +++ b/src/html/template/exec_test.go @@ -1191,15 +1191,19 @@ var cmpTests = []cmpTest{ {"eq .Iface1 .Iface1", "true", true}, {"eq .Iface1 .Iface2", "false", true}, {"eq .Iface2 .Iface2", "true", true}, + {"eq .Map .Map", "true", true}, // Uncomparable types but nil is OK. + {"eq .Map nil", "true", true}, // Uncomparable types but nil is OK. + {"eq nil .Map", "true", true}, // Uncomparable types but nil is OK. + {"eq .Map .NonNilMap", "false", true}, // Uncomparable types but nil is OK. // Errors - {"eq `xy` 1", "", false}, // Different types. - {"eq 2 2.0", "", false}, // Different types. - {"lt true true", "", false}, // Unordered types. - {"lt 1+0i 1+0i", "", false}, // Unordered types. - {"eq .Ptr 1", "", false}, // Incompatible types. - {"eq .Ptr .NegOne", "", false}, // Incompatible types. - {"eq .Map .Map", "", false}, // Uncomparable types. - {"eq .Map .V1", "", false}, // Uncomparable types. + {"eq `xy` 1", "", false}, // Different types. + {"eq 2 2.0", "", false}, // Different types. + {"lt true true", "", false}, // Unordered types. + {"lt 1+0i 1+0i", "", false}, // Unordered types. + {"eq .Ptr 1", "", false}, // Incompatible types. + {"eq .Ptr .NegOne", "", false}, // Incompatible types. + {"eq .Map .V1", "", false}, // Uncomparable types. + {"eq .NonNilMap .NonNilMap", "", false}, // Uncomparable types. } func TestComparison(t *testing.T) { @@ -1208,16 +1212,18 @@ func TestComparison(t *testing.T) { Uthree, Ufour uint NegOne, Three int Ptr, NilPtr *int + NonNilMap map[int]int Map map[int]int V1, V2 V Iface1, Iface2 fmt.Stringer }{ - Uthree: 3, - Ufour: 4, - NegOne: -1, - Three: 3, - Ptr: new(int), - Iface1: b, + Uthree: 3, + Ufour: 4, + NegOne: -1, + Three: 3, + Ptr: new(int), + NonNilMap: make(map[int]int), + Iface1: b, } for _, test := range cmpTests { text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr) diff --git a/src/text/template/exec_test.go b/src/text/template/exec_test.go index 8c8143396d..56566b920f 100644 --- a/src/text/template/exec_test.go +++ b/src/text/template/exec_test.go @@ -1220,15 +1220,19 @@ var cmpTests = []cmpTest{ {"eq .NilIface .Iface1", "false", true}, {"eq .NilIface 0", "false", true}, {"eq 0 .NilIface", "false", true}, + {"eq .Map .Map", "true", true}, // Uncomparable types but nil is OK. + {"eq .Map nil", "true", true}, // Uncomparable types but nil is OK. + {"eq nil .Map", "true", true}, // Uncomparable types but nil is OK. + {"eq .Map .NonNilMap", "false", true}, // Uncomparable types but nil is OK. // Errors - {"eq `xy` 1", "", false}, // Different types. - {"eq 2 2.0", "", false}, // Different types. - {"lt true true", "", false}, // Unordered types. - {"lt 1+0i 1+0i", "", false}, // Unordered types. - {"eq .Ptr 1", "", false}, // Incompatible types. - {"eq .Ptr .NegOne", "", false}, // Incompatible types. - {"eq .Map .Map", "", false}, // Uncomparable types. - {"eq .Map .V1", "", false}, // Uncomparable types. + {"eq `xy` 1", "", false}, // Different types. + {"eq 2 2.0", "", false}, // Different types. + {"lt true true", "", false}, // Unordered types. + {"lt 1+0i 1+0i", "", false}, // Unordered types. + {"eq .Ptr 1", "", false}, // Incompatible types. + {"eq .Ptr .NegOne", "", false}, // Incompatible types. + {"eq .Map .V1", "", false}, // Uncomparable types. + {"eq .NonNilMap .NonNilMap", "", false}, // Uncomparable types. } func TestComparison(t *testing.T) { @@ -1237,16 +1241,18 @@ func TestComparison(t *testing.T) { Uthree, Ufour uint NegOne, Three int Ptr, NilPtr *int + NonNilMap map[int]int Map map[int]int V1, V2 V Iface1, NilIface fmt.Stringer }{ - Uthree: 3, - Ufour: 4, - NegOne: -1, - Three: 3, - Ptr: new(int), - Iface1: b, + Uthree: 3, + Ufour: 4, + NegOne: -1, + Three: 3, + Ptr: new(int), + NonNilMap: make(map[int]int), + Iface1: b, } for _, test := range cmpTests { text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr) diff --git a/src/text/template/funcs.go b/src/text/template/funcs.go index dca5ed28db..1f63b361f8 100644 --- a/src/text/template/funcs.go +++ b/src/text/template/funcs.go @@ -436,14 +436,33 @@ func basicKind(v reflect.Value) (kind, error) { return invalidKind, errBadComparisonType } +// isNil returns true if v is the zero reflect.Value, or nil of its type. +func isNil(v reflect.Value) bool { + if v == zero { + return true + } + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + return v.IsNil() + } + return false +} + +// canCompare reports whether v1 and v2 are both the same kind, or one is nil. +// Called only when dealing with nillable types, or there's about to be an error. +func canCompare(v1, v2 reflect.Value) bool { + k1 := v1.Kind() + k2 := v2.Kind() + if k1 == k2 { + return true + } + // We know the type can be compared to nil. + return k1 == reflect.Invalid || k2 == reflect.Invalid +} + // eq evaluates the comparison a == b || a == c || ... func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) { arg1 = indirectInterface(arg1) - if arg1 != zero { - if t1 := arg1.Type(); !t1.Comparable() { - return false, fmt.Errorf("uncomparable type %s: %v", t1, arg1) - } - } if len(arg2) == 0 { return false, errNoComparison } @@ -479,11 +498,14 @@ func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) { case uintKind: truth = arg1.Uint() == arg.Uint() default: - if arg == zero || arg1 == zero { - truth = arg1 == arg + if !canCompare(arg1, arg) { + return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg) + } + if isNil(arg1) || isNil(arg) { + truth = isNil(arg) == isNil(arg1) } else { - if t2 := arg.Type(); !t2.Comparable() { - return false, fmt.Errorf("uncomparable type %s: %v", t2, arg) + if !arg.Type().Comparable() { + return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type()) } truth = arg1.Interface() == arg.Interface() } -- 2.50.0