]> Cypherpunks repositories - gostls13.git/commitdiff
exp/template: allow fields and methods to be found in parent structs.
authorRob Pike <r@golang.org>
Fri, 8 Jul 2011 05:22:05 +0000 (15:22 +1000)
committerRob Pike <r@golang.org>
Fri, 8 Jul 2011 05:22:05 +0000 (15:22 +1000)
R=golang-dev, adg
CC=golang-dev
https://golang.org/cl/4657085

src/pkg/exp/template/exec.go
src/pkg/exp/template/exec_test.go

index fb0a9e62186ff3d183712a25ffb9c5fedb8eb2b1..09bf8587e08a466892997467fca8744e4fef280f 100644 (file)
@@ -22,8 +22,25 @@ type state struct {
        wr   io.Writer
        set  *Set
        line int // line number for errors
+       // parent holds the state for the surrounding data object,
+       // typically the structure containing the field we are evaluating.
+       parent struct {
+               state *state
+               data  reflect.Value
+       }
+}
+
+// down returns a new state representing a child of the current state.
+// data represents the parent of the child.
+func (s *state) down(data reflect.Value) *state {
+       var child = *s
+       child.parent.state = s
+       child.parent.data = data
+       return &child
 }
 
+var zero reflect.Value
+
 // errorf formats the error and terminates processing.
 func (s *state) errorf(format string, args ...interface{}) {
        format = fmt.Sprintf("template: %s:%d: %s", s.tmpl.name, s.line, format)
@@ -101,9 +118,10 @@ func (s *state) walkIfOrWith(typ nodeType, data reflect.Value, pipe []*commandNo
        }
        if truth {
                if typ == nodeWith {
-                       data = val
+                       s.down(data).walk(val, list)
+               } else {
+                       s.walk(data, list)
                }
-               s.walk(data, list)
        } else if elseList != nil {
                s.walk(data, elseList)
        }
@@ -135,13 +153,14 @@ func isTrue(val reflect.Value) (truth, ok bool) {
 
 func (s *state) walkRange(data reflect.Value, r *rangeNode) {
        val := s.evalPipeline(data, r.pipeline)
+       down := s.down(data)
        switch val.Kind() {
        case reflect.Array, reflect.Slice:
                if val.Len() == 0 {
                        break
                }
                for i := 0; i < val.Len(); i++ {
-                       s.walk(val.Index(i), r.list)
+                       down.walk(val.Index(i), r.list)
                }
                return
        case reflect.Map:
@@ -149,7 +168,7 @@ func (s *state) walkRange(data reflect.Value, r *rangeNode) {
                        break
                }
                for _, key := range val.MapKeys() {
-                       s.walk(val.MapIndex(key), r.list)
+                       down.walk(val.MapIndex(key), r.list)
                }
                return
        default:
@@ -180,7 +199,7 @@ func (s *state) walkTemplate(data reflect.Value, t *templateNode) {
 // The printing of those values happens only through walk functions.
 
 func (s *state) evalPipeline(data reflect.Value, pipe []*commandNode) reflect.Value {
-       value := reflect.Value{}
+       value := zero
        for _, cmd := range pipe {
                value = s.evalCommand(data, cmd, value) // previous value is this one's final arg.
                // If the object has type interface{}, dig down one level to the thing inside.
@@ -197,7 +216,7 @@ func (s *state) evalCommand(data reflect.Value, cmd *commandNode, final reflect.
        case *fieldNode:
                return s.evalFieldNode(data, n, cmd.args, final)
        case *identifierNode:
-               return s.evalFieldOrCall(data, n.ident, cmd.args, final)
+               return s.evalField(data, n.ident, cmd.args, final, true, true)
        }
        if len(cmd.args) > 1 || final.IsValid() {
                s.errorf("can't give argument to non-function %s", cmd.args[0])
@@ -232,10 +251,10 @@ func (s *state) evalFieldNode(data reflect.Value, field *fieldNode, args []node,
        // Up to the last entry, it must be a field.
        n := len(field.ident)
        for i := 0; i < n-1; i++ {
-               data = s.evalField(data, field.ident[i])
+               data = s.evalField(data, field.ident[i], nil, zero, i == 0, false)
        }
        // Now it can be a field or method and if a method, gets arguments.
-       return s.evalFieldOrCall(data, field.ident[n-1], args, final)
+       return s.evalField(data, field.ident[n-1], args, final, len(field.ident) == 1, true)
 }
 
 // Is this an exported - upper case - name?
@@ -244,49 +263,55 @@ func isExported(name string) bool {
        return unicode.IsUpper(rune)
 }
 
-func (s *state) evalField(data reflect.Value, fieldName string) reflect.Value {
-       var isNil bool
-       if data, isNil = indirect(data); isNil {
-               s.errorf("%s is nil pointer", fieldName)
-       }
-       switch data.Kind() {
-       case reflect.Struct:
-               // Is it a field?
-               field := data.FieldByName(fieldName)
-               // TODO: look higher up the tree if we can't find it here. Also unexported fields
-               // might succeed higher up, as map keys.
-               if field.IsValid() && isExported(fieldName) { // valid and exported
-                       return field
-               }
-               s.errorf("%s has no exported field %q", data.Type(), fieldName)
-       default:
-               s.errorf("can't evaluate field %s of type %s", fieldName, data.Type())
-       }
-       panic("not reached")
-}
-
-func (s *state) evalFieldOrCall(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value {
+// evalField evaluates an expression like (.Field) or (.Field arg1 arg2).
+// The 'final' argument represents the return value from the preceding
+// value of the pipeline, if any.
+// If we're in a chain, such as (.X.Y.Z), .X and .Y cannot be methods;
+//canBeMethod will be true only for the last element of such chains (here .Z).
+// The isFirst argument tells whether this is the first element of a chain (here .X).
+// If true, evaluation is allowed to examine the parent to resolve the reference.
+func (s *state) evalField(data reflect.Value, fieldName string, args []node, final reflect.Value,
+isFirst, canBeMethod bool) reflect.Value {
+       topState, topData := s, data // Remember initial state for diagnostics.
        // Is it a function?
        if function, ok := findFunction(fieldName, s.tmpl, s.set); ok {
                return s.evalCall(data, function, fieldName, false, args, final)
        }
-       ptr := data
-       for data.Kind() == reflect.Ptr && !data.IsNil() {
-               ptr, data = data, reflect.Indirect(data)
-       }
-       // Is it a method? We use the pointer because it has value methods too.
-       if method, ok := methodByName(ptr.Type(), fieldName); ok {
-               return s.evalCall(ptr, method.Func, fieldName, true, args, final)
-       }
-       if len(args) > 1 || final.IsValid() {
-               s.errorf("%s is not a method but has arguments", fieldName)
-       }
-       switch data.Kind() {
-       case reflect.Struct:
-               return s.evalField(data, fieldName)
-       default:
-               s.errorf("can't handle evaluation of field %s of type %s", fieldName, data.Type())
+       // Look for methods and fields at this level, and then in the parent.
+       for s != nil {
+               var isNil bool
+               data, isNil = indirect(data)
+               if canBeMethod {
+                       // Need to get to a value of type *T to guarantee we see all
+                       // methods of T and *T.
+                       ptr := data.Addr()
+                       if method, ok := methodByName(ptr.Type(), fieldName); ok {
+                               return s.evalCall(ptr, method.Func, fieldName, true, args, final)
+                       }
+               }
+               // It's not a method; is it a field of a struct?
+               if data.Kind() == reflect.Struct {
+                       field := data.FieldByName(fieldName)
+                       if field.IsValid() {
+                               if len(args) > 1 || final.IsValid() {
+                                       s.errorf("%s is not a method but has arguments", fieldName)
+                               }
+                               if isExported(fieldName) { // valid and exported
+                                       return field
+                               }
+                       }
+               }
+               if !isFirst {
+                       // We check for nil pointers only if there's no possibility of resolution
+                       // in the parent.
+                       if isNil {
+                               s.errorf("nil pointer evaluating %s.%s", topData.Type(), fieldName)
+                       }
+                       break
+               }
+               s, data = s.parent.state, s.parent.data
        }
+       topState.errorf("can't handle evaluation of field %s in type %s", fieldName, topData.Type())
        panic("not reached")
 }
 
@@ -358,7 +383,7 @@ func (s *state) evalCall(v, fun reflect.Value, name string, isMethod bool, args
 
 func (s *state) evalArg(data reflect.Value, typ reflect.Type, n node) reflect.Value {
        if field, ok := n.(*fieldNode); ok {
-               value := s.evalFieldNode(data, field, []node{n}, reflect.Value{})
+               value := s.evalFieldNode(data, field, []node{n}, zero)
                if !value.Type().AssignableTo(typ) {
                        s.errorf("wrong type for value; expected %s; got %s", typ, value.Type())
                }
@@ -453,9 +478,9 @@ func (s *state) evalEmptyInterface(data reflect.Value, typ reflect.Type, n node)
        case *dotNode:
                return data
        case *fieldNode:
-               return s.evalFieldNode(data, n, nil, reflect.Value{})
+               return s.evalFieldNode(data, n, nil, zero)
        case *identifierNode:
-               return s.evalFieldOrCall(data, n.ident, nil, reflect.Value{})
+               return s.evalField(data, n.ident, nil, zero, false, true)
        case *numberNode:
                if n.isComplex {
                        return reflect.ValueOf(n.complex128)
index 86b958e840f04a454abee411dd329102aeb54260..db3e89f63d44c6e3d5363d30c90286611bc031c1 100644 (file)
@@ -44,6 +44,10 @@ type T struct {
        NIL *int
 }
 
+type U struct {
+       V string
+}
+
 var tVal = &T{
        I:      17,
        U16:    16,
@@ -81,7 +85,7 @@ func newIntSlice(n ...int) *[]int {
 
 // Simple methods with and without arguments.
 func (t *T) Method0() string {
-       return "resultOfMethod0"
+       return "M0"
 }
 
 func (t *T) Method1(a int) int {
@@ -120,10 +124,6 @@ func (t *T) EPERM(error bool) (bool, os.Error) {
        return false, nil
 }
 
-type U struct {
-       V string
-}
-
 type execTest struct {
        name   string
        input  string
@@ -169,14 +169,14 @@ var execTests = []execTest{
        {"empty with struct", "{{.Empty4}}", "{v}", tVal, true},
 
        // Method calls.
-       {".Method0", "-{{.Method0}}-", "-resultOfMethod0-", tVal, true},
+       {".Method0", "-{{.Method0}}-", "-M0-", tVal, true},
        {".Method1(1234)", "-{{.Method1 1234}}-", "-1234-", tVal, true},
        {".Method1(.I)", "-{{.Method1 .I}}-", "-17-", tVal, true},
        {".Method2(3, .X)", "-{{.Method2 3 .X}}-", "-Method2: 3 x-", tVal, true},
        {".Method2(.U16, `str`)", "-{{.Method2 .U16 `str`}}-", "-Method2: 16 str-", tVal, true},
 
        // Pipelines.
-       {"pipeline", "-{{.Method0 | .Method2 .U16}}-", "-Method2: 16 resultOfMethod0-", tVal, true},
+       {"pipeline", "-{{.Method0 | .Method2 .U16}}-", "-Method2: 16 M0-", tVal, true},
 
        // If.
        {"if true", "{{if true}}TRUE{{end}}", "TRUE", tVal, true},
@@ -202,8 +202,8 @@ var execTests = []execTest{
        {"printf string", `{{printf "%s" "hello"}}`, "hello", tVal, true},
        {"printf function", `{{printf "%#q" gopher}}`, "`gopher`", tVal, true},
        {"printf field", `{{printf "%s" .U.V}}`, "v", tVal, true},
-       {"printf method", `{{printf "%s" .Method0}}`, "resultOfMethod0", tVal, true},
-       {"printf lots", `{{printf "%d %s %g %s" 127 "hello" 7-3i .Method0}}`, "127 hello (7-3i) resultOfMethod0", tVal, true},
+       {"printf method", `{{printf "%s" .Method0}}`, "M0", tVal, true},
+       {"printf lots", `{{printf "%d %s %g %s" 127 "hello" 7-3i .Method0}}`, "127 hello (7-3i) M0", tVal, true},
 
        // HTML.
        {"html", `{{html "<script>alert(\"XSS\");</script>"}}`,
@@ -249,6 +249,12 @@ var execTests = []execTest{
        {"with map", "{{with .MSIone}}{{.}}{{else}}EMPTY{{end}}", "map[one:1]", tVal, true},
        {"with empty interface, struct field", "{{with .Empty4}}{{.V}}{{end}}", "v", tVal, true},
 
+       // Fields and methods in parent struct.
+       {"with .U, get .I", "{{with .U}}{{.I}}{{end}}", "17", tVal, true},
+       {"with .U, do .Method0", "{{with .U}}{{.Method0}}{{end}}", "M0", tVal, true},
+       {"range .SI .I", "{{range .SI}}<{{.I}}>{{end}}", "<17><17><17>", tVal, true},
+       {"range .SI .Method0", "{{range .SI}}{{.Method0}}{{end}}", "M0M0M0", tVal, true},
+
        // Range.
        {"range []int", "{{range .SI}}-{{.}}-{{end}}", "-3--4--5-", tVal, true},
        {"range empty no else", "{{range .SIEmpty}}-{{.}}-{{end}}", "", tVal, true},