]> Cypherpunks repositories - gostls13.git/commitdiff
template: regularize the handling of interfaces, pointers, and
authorRob Pike <r@golang.org>
Sun, 9 May 2010 23:40:38 +0000 (16:40 -0700)
committerRob Pike <r@golang.org>
Sun, 9 May 2010 23:40:38 +0000 (16:40 -0700)
methods when looking up names.
Fixes #764.

R=rsc
CC=golang-dev
https://golang.org/cl/1170041

src/pkg/template/template.go

index 73789c23afb5e8b79f0dcb074fea2f735875b188..334559c13ca6d8d682549f25f6fb6cf5c7a099f1 100644 (file)
@@ -575,6 +575,55 @@ func (t *Template) parse() {
 
 // -- Execution
 
+// Evaluate interfaces and pointers looking for a value that can look up the name, via a
+// struct field, method, or map key, and return the result of the lookup.
+func lookup(v reflect.Value, name string) reflect.Value {
+       for v != nil {
+               typ := v.Type()
+               if n := v.Type().NumMethod(); n > 0 {
+                       for i := 0; i < n; i++ {
+                               m := typ.Method(i)
+                               mtyp := m.Type
+                               // We must check receiver type because of a bug in the reflection type tables:
+                               // it should not be possible to find a method with the wrong receiver type but
+                               // this can happen due to value/pointer receiver mismatch.
+                               if m.Name == name && mtyp.NumIn() == 1 && mtyp.NumOut() == 1 && mtyp.In(0) == typ {
+                                       return v.Method(i).Call(nil)[0]
+                               }
+                       }
+               }
+               switch av := v.(type) {
+               case *reflect.PtrValue:
+                       v = av.Elem()
+               case *reflect.InterfaceValue:
+                       v = av.Elem()
+               case *reflect.StructValue:
+                       return av.FieldByName(name)
+               case *reflect.MapValue:
+                       return av.Elem(reflect.NewValue(name))
+               default:
+                       return nil
+               }
+       }
+       return v
+}
+
+// Walk v through pointers and interfaces, extracting the elements within.
+func indirect(v reflect.Value) reflect.Value {
+loop:
+       for v != nil {
+               switch av := v.(type) {
+               case *reflect.PtrValue:
+                       v = av.Elem()
+               case *reflect.InterfaceValue:
+                       v = av.Elem()
+               default:
+                       break loop
+               }
+       }
+       return v
+}
+
 // If the data for this template is a struct, find the named variable.
 // Names of the form a.b.c are walked down the data tree.
 // The special name "@" (the "cursor") denotes the current data.
@@ -587,88 +636,18 @@ func (st *state) findVar(s string) reflect.Value {
        }
        data := st.data
        for _, elem := range strings.Split(s, ".", 0) {
-               origData := data // for method lookup need value before indirection.
                // Look up field; data must be a struct or map.
-               data = reflect.Indirect(data)
+               data = lookup(data, elem)
                if data == nil {
                        return nil
                }
-               if intf, ok := data.(*reflect.InterfaceValue); ok {
-                       data = reflect.Indirect(intf.Elem())
-               }
-
-               switch typ := data.Type().(type) {
-               case *reflect.StructType:
-                       if field, ok := typ.FieldByName(elem); ok {
-                               data = data.(*reflect.StructValue).FieldByIndex(field.Index)
-                               continue
-                       }
-               case *reflect.MapType:
-                       data = data.(*reflect.MapValue).Elem(reflect.NewValue(elem))
-                       continue
-               }
-
-               // No luck with that name; is it a method?
-               if result, found := callMethod(origData, elem); found {
-                       data = result
-                       continue
-               }
-               return nil
        }
        return data
 }
 
-// See if name is a method of the value at some level of indirection.
-// The return values are the result of the call (which may be nil if
-// there's trouble) and whether a method of the right name exists with
-// any signature.
-func callMethod(data reflect.Value, name string) (result reflect.Value, found bool) {
-       found = false
-       // Method set depends on pointerness, and the value may be arbitrarily
-       // indirect.  Simplest approach is to walk down the pointer chain and
-       // see if we can find the method at each step.
-       // Most steps will see NumMethod() == 0.
-       for {
-               typ := data.Type()
-               if nMethod := data.Type().NumMethod(); nMethod > 0 {
-                       for i := 0; i < nMethod; i++ {
-                               method := typ.Method(i)
-                               if method.Name == name {
-                                       found = true // we found the name regardless
-                                       // does receiver type match? (pointerness might be off)
-                                       if typ == method.Type.In(0) {
-                                               return call(data, method), found
-                                       }
-                               }
-                       }
-               }
-               if nd, ok := data.(*reflect.PtrValue); ok {
-                       data = nd.Elem()
-               } else {
-                       break
-               }
-       }
-       return
-}
-
-// Invoke the method. If its signature is wrong, return nil.
-func call(v reflect.Value, method reflect.Method) reflect.Value {
-       funcType := method.Type
-       // Method must take no arguments, meaning as a func it has one argument (the receiver)
-       if funcType.NumIn() != 1 {
-               return nil
-       }
-       // Method must return a single value.
-       if funcType.NumOut() != 1 {
-               return nil
-       }
-       // Result will be the zeroth element of the returned slice.
-       return method.Func.Call([]reflect.Value{v})[0]
-}
-
 // Is there no data to look at?
 func empty(v reflect.Value) bool {
-       v = reflect.Indirect(v)
+       v = indirect(v)
        if v == nil {
                return true
        }
@@ -694,13 +673,10 @@ func (t *Template) varValue(name string, st *state) reflect.Value {
        field := st.findVar(name)
        if field == nil {
                if st.parent == nil {
-                       t.execError(st, t.linenum, "name not found: %s", name)
+                       t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type())
                }
                return t.varValue(name, st.parent)
        }
-       if iface, ok := field.(*reflect.InterfaceValue); ok && !iface.IsNil() {
-               field = iface.Elem()
-       }
        return field
 }
 
@@ -760,7 +736,7 @@ func (t *Template) executeSection(s *sectionElement, st *state) {
        // Find driver data for this section.  It must be in the current struct.
        field := t.varValue(s.field, st)
        if field == nil {
-               t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, reflect.Indirect(st.data).Type())
+               t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type())
        }
        st = st.clone(field)
        start, end := s.start, s.or
@@ -805,8 +781,9 @@ func (t *Template) executeRepeated(r *repeatedElement, st *state) {
        // Find driver data for this section.  It must be in the current struct.
        field := t.varValue(r.field, st)
        if field == nil {
-               t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, reflect.Indirect(st.data).Type())
+               t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type())
        }
+       field = indirect(field)
 
        start, end := r.start, r.or
        if end < 0 {
@@ -947,7 +924,7 @@ func Parse(s string, fmap FormatterMap) (t *Template, err os.Error) {
 }
 
 // ParseFile is a wrapper function that creates a Template with default
-// parameters (such as {} for // metacharacters).  The filename identfies
+// parameters (such as {} for metacharacters).  The filename identifies
 // a file containing the template text, while the formatter map fmap, which
 // may be nil, defines auxiliary functions for formatting variables.
 // The template is returned. If any errors occur, err will be non-nil.