]> Cypherpunks repositories - gostls13.git/commitdiff
text/template: make reflect.Value indirections more robust
authorDaniel Martí <mvdan@mvdan.cc>
Thu, 19 Dec 2019 14:13:59 +0000 (14:13 +0000)
committerDaniel Martí <mvdan@mvdan.cc>
Mon, 24 Feb 2020 09:16:18 +0000 (09:16 +0000)
Always shadow or modify the original parameter name. With code like:

func index(item reflect.Value, ... {
v := indirectInterface(item)

It was possible to incorrectly use 'item' and 'v' later in the function,
which could result in subtle bugs. This is precisely the kind of mistake
that led to #36199.

Instead, don't keep both the old and new reflect.Value variables in
scope. Always shadow or modify the original variable.

While at it, simplify the signature of 'length', to receive a
reflect.Value directly and save a few redundant lines.

Change-Id: I01416636a9d49f81246d28b91aca6413b1ba1aa5
Reviewed-on: https://go-review.googlesource.com/c/go/+/212117
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Roberto Clapis <robclap8@gmail.com>
Reviewed-by: Rob Pike <r@golang.org>
src/text/template/funcs.go

index 46125bc2168b9dffc1545d382b2b9350074706f8..6a6843dfa08f99ee78879557b12d9699d3b78b95 100644 (file)
@@ -185,41 +185,41 @@ func indexArg(index reflect.Value, cap int) (int, error) {
 // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
 // indexed item must be a map, slice, or array.
 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
-       v := indirectInterface(item)
-       if !v.IsValid() {
+       item = indirectInterface(item)
+       if !item.IsValid() {
                return reflect.Value{}, fmt.Errorf("index of untyped nil")
        }
-       for _, i := range indexes {
-               index := indirectInterface(i)
+       for _, index := range indexes {
+               index = indirectInterface(index)
                var isNil bool
-               if v, isNil = indirect(v); isNil {
+               if item, isNil = indirect(item); isNil {
                        return reflect.Value{}, fmt.Errorf("index of nil pointer")
                }
-               switch v.Kind() {
+               switch item.Kind() {
                case reflect.Array, reflect.Slice, reflect.String:
-                       x, err := indexArg(index, v.Len())
+                       x, err := indexArg(index, item.Len())
                        if err != nil {
                                return reflect.Value{}, err
                        }
-                       v = v.Index(x)
+                       item = item.Index(x)
                case reflect.Map:
-                       index, err := prepareArg(index, v.Type().Key())
+                       index, err := prepareArg(index, item.Type().Key())
                        if err != nil {
                                return reflect.Value{}, err
                        }
-                       if x := v.MapIndex(index); x.IsValid() {
-                               v = x
+                       if x := item.MapIndex(index); x.IsValid() {
+                               item = x
                        } else {
-                               v = reflect.Zero(v.Type().Elem())
+                               item = reflect.Zero(item.Type().Elem())
                        }
                case reflect.Invalid:
-                       // the loop holds invariant: v.IsValid()
+                       // the loop holds invariant: item.IsValid()
                        panic("unreachable")
                default:
-                       return reflect.Value{}, fmt.Errorf("can't index item of type %s", v.Type())
+                       return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
                }
        }
-       return v, nil
+       return item, nil
 }
 
 // Slicing.
@@ -229,29 +229,27 @@ func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error)
 // is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
 // argument must be a string, slice, or array.
 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
-       var (
-               cap int
-               v   = indirectInterface(item)
-       )
-       if !v.IsValid() {
+       item = indirectInterface(item)
+       if !item.IsValid() {
                return reflect.Value{}, fmt.Errorf("slice of untyped nil")
        }
        if len(indexes) > 3 {
                return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
        }
-       switch v.Kind() {
+       var cap int
+       switch item.Kind() {
        case reflect.String:
                if len(indexes) == 3 {
                        return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
                }
-               cap = v.Len()
+               cap = item.Len()
        case reflect.Array, reflect.Slice:
-               cap = v.Cap()
+               cap = item.Cap()
        default:
-               return reflect.Value{}, fmt.Errorf("can't slice item of type %s", v.Type())
+               return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
        }
        // set default values for cases item[:], item[i:].
-       idx := [3]int{0, v.Len()}
+       idx := [3]int{0, item.Len()}
        for i, index := range indexes {
                x, err := indexArg(index, cap)
                if err != nil {
@@ -264,32 +262,28 @@ func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error)
                return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
        }
        if len(indexes) < 3 {
-               return v.Slice(idx[0], idx[1]), nil
+               return item.Slice(idx[0], idx[1]), nil
        }
        // given item[i:j:k], make sure i <= j <= k.
        if idx[1] > idx[2] {
                return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
        }
-       return v.Slice3(idx[0], idx[1], idx[2]), nil
+       return item.Slice3(idx[0], idx[1], idx[2]), nil
 }
 
 // Length
 
 // length returns the length of the item, with an error if it has no defined length.
-func length(item interface{}) (int, error) {
-       v := reflect.ValueOf(item)
-       if !v.IsValid() {
-               return 0, fmt.Errorf("len of untyped nil")
-       }
-       v, isNil := indirect(v)
+func length(item reflect.Value) (int, error) {
+       item, isNil := indirect(item)
        if isNil {
                return 0, fmt.Errorf("len of nil pointer")
        }
-       switch v.Kind() {
+       switch item.Kind() {
        case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
-               return v.Len(), nil
+               return item.Len(), nil
        }
-       return 0, fmt.Errorf("len of type %s", v.Type())
+       return 0, fmt.Errorf("len of type %s", item.Type())
 }
 
 // Function invocation
@@ -297,11 +291,11 @@ func length(item interface{}) (int, error) {
 // call returns the result of evaluating the first argument as a function.
 // The function must return 1 result, or 2 results, the second of which is an error.
 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
-       v := indirectInterface(fn)
-       if !v.IsValid() {
+       fn = indirectInterface(fn)
+       if !fn.IsValid() {
                return reflect.Value{}, fmt.Errorf("call of nil")
        }
-       typ := v.Type()
+       typ := fn.Type()
        if typ.Kind() != reflect.Func {
                return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
        }
@@ -322,7 +316,7 @@ func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
        }
        argv := make([]reflect.Value, len(args))
        for i, arg := range args {
-               value := indirectInterface(arg)
+               arg = indirectInterface(arg)
                // Compute the expected type. Clumsy because of variadics.
                argType := dddType
                if !typ.IsVariadic() || i < numIn-1 {
@@ -330,11 +324,11 @@ func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
                }
 
                var err error
-               if argv[i], err = prepareArg(value, argType); err != nil {
+               if argv[i], err = prepareArg(arg, argType); err != nil {
                        return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
                }
        }
-       return safeCall(v, argv)
+       return safeCall(fn, argv)
 }
 
 // safeCall runs fun.Call(args), and returns the resulting value and error, if
@@ -440,52 +434,52 @@ func basicKind(v reflect.Value) (kind, error) {
 
 // eq evaluates the comparison a == b || a == c || ...
 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
-       v1 := indirectInterface(arg1)
-       if v1 != zero {
-               if t1 := v1.Type(); !t1.Comparable() {
-                       return false, fmt.Errorf("uncomparable type %s: %v", t1, v1)
+       arg1 = indirectInterface(arg1)
+       if arg1 != zero {
+               if t1 := arg1.Type(); !t1.Comparable() {
+                       return false, fmt.Errorf("uncomparable type %s: %v", t1, arg1)
                }
        }
        if len(arg2) == 0 {
                return false, errNoComparison
        }
-       k1, _ := basicKind(v1)
+       k1, _ := basicKind(arg1)
        for _, arg := range arg2 {
-               v2 := indirectInterface(arg)
-               k2, _ := basicKind(v2)
+               arg = indirectInterface(arg)
+               k2, _ := basicKind(arg)
                truth := false
                if k1 != k2 {
                        // Special case: Can compare integer values regardless of type's sign.
                        switch {
                        case k1 == intKind && k2 == uintKind:
-                               truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
+                               truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
                        case k1 == uintKind && k2 == intKind:
-                               truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
+                               truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
                        default:
                                return false, errBadComparison
                        }
                } else {
                        switch k1 {
                        case boolKind:
-                               truth = v1.Bool() == v2.Bool()
+                               truth = arg1.Bool() == arg.Bool()
                        case complexKind:
-                               truth = v1.Complex() == v2.Complex()
+                               truth = arg1.Complex() == arg.Complex()
                        case floatKind:
-                               truth = v1.Float() == v2.Float()
+                               truth = arg1.Float() == arg.Float()
                        case intKind:
-                               truth = v1.Int() == v2.Int()
+                               truth = arg1.Int() == arg.Int()
                        case stringKind:
-                               truth = v1.String() == v2.String()
+                               truth = arg1.String() == arg.String()
                        case uintKind:
-                               truth = v1.Uint() == v2.Uint()
+                               truth = arg1.Uint() == arg.Uint()
                        default:
-                               if v2 == zero {
-                                       truth = v1 == v2
+                               if arg == zero {
+                                       truth = arg1 == arg
                                } else {
-                                       if t2 := v2.Type(); !t2.Comparable() {
-                                               return false, fmt.Errorf("uncomparable type %s: %v", t2, v2)
+                                       if t2 := arg.Type(); !t2.Comparable() {
+                                               return false, fmt.Errorf("uncomparable type %s: %v", t2, arg)
                                        }
-                                       truth = v1.Interface() == v2.Interface()
+                                       truth = arg1.Interface() == arg.Interface()
                                }
                        }
                }
@@ -505,13 +499,13 @@ func ne(arg1, arg2 reflect.Value) (bool, error) {
 
 // lt evaluates the comparison a < b.
 func lt(arg1, arg2 reflect.Value) (bool, error) {
-       v1 := indirectInterface(arg1)
-       k1, err := basicKind(v1)
+       arg1 = indirectInterface(arg1)
+       k1, err := basicKind(arg1)
        if err != nil {
                return false, err
        }
-       v2 := indirectInterface(arg2)
-       k2, err := basicKind(v2)
+       arg2 = indirectInterface(arg2)
+       k2, err := basicKind(arg2)
        if err != nil {
                return false, err
        }
@@ -520,9 +514,9 @@ func lt(arg1, arg2 reflect.Value) (bool, error) {
                // Special case: Can compare integer values regardless of type's sign.
                switch {
                case k1 == intKind && k2 == uintKind:
-                       truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
+                       truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
                case k1 == uintKind && k2 == intKind:
-                       truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
+                       truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
                default:
                        return false, errBadComparison
                }
@@ -531,13 +525,13 @@ func lt(arg1, arg2 reflect.Value) (bool, error) {
                case boolKind, complexKind:
                        return false, errBadComparisonType
                case floatKind:
-                       truth = v1.Float() < v2.Float()
+                       truth = arg1.Float() < arg2.Float()
                case intKind:
-                       truth = v1.Int() < v2.Int()
+                       truth = arg1.Int() < arg2.Int()
                case stringKind:
-                       truth = v1.String() < v2.String()
+                       truth = arg1.String() < arg2.String()
                case uintKind:
-                       truth = v1.Uint() < v2.Uint()
+                       truth = arg1.Uint() < arg2.Uint()
                default:
                        panic("invalid kind")
                }