From b717768b94d1afed70eb035221327fa9fee371c8 Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Sun, 9 May 2010 16:40:38 -0700 Subject: [PATCH] template: regularize the handling of interfaces, pointers, and methods when looking up names. Fixes #764. R=rsc CC=golang-dev https://golang.org/cl/1170041 --- src/pkg/template/template.go | 135 +++++++++++++++-------------------- 1 file changed, 56 insertions(+), 79 deletions(-) diff --git a/src/pkg/template/template.go b/src/pkg/template/template.go index 73789c23af..334559c13c 100644 --- a/src/pkg/template/template.go +++ b/src/pkg/template/template.go @@ -575,6 +575,55 @@ func (t *Template) parse() { // -- Execution +// Evaluate interfaces and pointers looking for a value that can look up the name, via a +// struct field, method, or map key, and return the result of the lookup. +func lookup(v reflect.Value, name string) reflect.Value { + for v != nil { + typ := v.Type() + if n := v.Type().NumMethod(); n > 0 { + for i := 0; i < n; i++ { + m := typ.Method(i) + mtyp := m.Type + // We must check receiver type because of a bug in the reflection type tables: + // it should not be possible to find a method with the wrong receiver type but + // this can happen due to value/pointer receiver mismatch. + if m.Name == name && mtyp.NumIn() == 1 && mtyp.NumOut() == 1 && mtyp.In(0) == typ { + return v.Method(i).Call(nil)[0] + } + } + } + switch av := v.(type) { + case *reflect.PtrValue: + v = av.Elem() + case *reflect.InterfaceValue: + v = av.Elem() + case *reflect.StructValue: + return av.FieldByName(name) + case *reflect.MapValue: + return av.Elem(reflect.NewValue(name)) + default: + return nil + } + } + return v +} + +// Walk v through pointers and interfaces, extracting the elements within. +func indirect(v reflect.Value) reflect.Value { +loop: + for v != nil { + switch av := v.(type) { + case *reflect.PtrValue: + v = av.Elem() + case *reflect.InterfaceValue: + v = av.Elem() + default: + break loop + } + } + return v +} + // If the data for this template is a struct, find the named variable. // Names of the form a.b.c are walked down the data tree. // The special name "@" (the "cursor") denotes the current data. @@ -587,88 +636,18 @@ func (st *state) findVar(s string) reflect.Value { } data := st.data for _, elem := range strings.Split(s, ".", 0) { - origData := data // for method lookup need value before indirection. // Look up field; data must be a struct or map. - data = reflect.Indirect(data) + data = lookup(data, elem) if data == nil { return nil } - if intf, ok := data.(*reflect.InterfaceValue); ok { - data = reflect.Indirect(intf.Elem()) - } - - switch typ := data.Type().(type) { - case *reflect.StructType: - if field, ok := typ.FieldByName(elem); ok { - data = data.(*reflect.StructValue).FieldByIndex(field.Index) - continue - } - case *reflect.MapType: - data = data.(*reflect.MapValue).Elem(reflect.NewValue(elem)) - continue - } - - // No luck with that name; is it a method? - if result, found := callMethod(origData, elem); found { - data = result - continue - } - return nil } return data } -// See if name is a method of the value at some level of indirection. -// The return values are the result of the call (which may be nil if -// there's trouble) and whether a method of the right name exists with -// any signature. -func callMethod(data reflect.Value, name string) (result reflect.Value, found bool) { - found = false - // Method set depends on pointerness, and the value may be arbitrarily - // indirect. Simplest approach is to walk down the pointer chain and - // see if we can find the method at each step. - // Most steps will see NumMethod() == 0. - for { - typ := data.Type() - if nMethod := data.Type().NumMethod(); nMethod > 0 { - for i := 0; i < nMethod; i++ { - method := typ.Method(i) - if method.Name == name { - found = true // we found the name regardless - // does receiver type match? (pointerness might be off) - if typ == method.Type.In(0) { - return call(data, method), found - } - } - } - } - if nd, ok := data.(*reflect.PtrValue); ok { - data = nd.Elem() - } else { - break - } - } - return -} - -// Invoke the method. If its signature is wrong, return nil. -func call(v reflect.Value, method reflect.Method) reflect.Value { - funcType := method.Type - // Method must take no arguments, meaning as a func it has one argument (the receiver) - if funcType.NumIn() != 1 { - return nil - } - // Method must return a single value. - if funcType.NumOut() != 1 { - return nil - } - // Result will be the zeroth element of the returned slice. - return method.Func.Call([]reflect.Value{v})[0] -} - // Is there no data to look at? func empty(v reflect.Value) bool { - v = reflect.Indirect(v) + v = indirect(v) if v == nil { return true } @@ -694,13 +673,10 @@ func (t *Template) varValue(name string, st *state) reflect.Value { field := st.findVar(name) if field == nil { if st.parent == nil { - t.execError(st, t.linenum, "name not found: %s", name) + t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type()) } return t.varValue(name, st.parent) } - if iface, ok := field.(*reflect.InterfaceValue); ok && !iface.IsNil() { - field = iface.Elem() - } return field } @@ -760,7 +736,7 @@ func (t *Template) executeSection(s *sectionElement, st *state) { // Find driver data for this section. It must be in the current struct. field := t.varValue(s.field, st) if field == nil { - t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, reflect.Indirect(st.data).Type()) + t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type()) } st = st.clone(field) start, end := s.start, s.or @@ -805,8 +781,9 @@ func (t *Template) executeRepeated(r *repeatedElement, st *state) { // Find driver data for this section. It must be in the current struct. field := t.varValue(r.field, st) if field == nil { - t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, reflect.Indirect(st.data).Type()) + t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type()) } + field = indirect(field) start, end := r.start, r.or if end < 0 { @@ -947,7 +924,7 @@ func Parse(s string, fmap FormatterMap) (t *Template, err os.Error) { } // ParseFile is a wrapper function that creates a Template with default -// parameters (such as {} for // metacharacters). The filename identfies +// parameters (such as {} for metacharacters). The filename identifies // a file containing the template text, while the formatter map fmap, which // may be nil, defines auxiliary functions for formatting variables. // The template is returned. If any errors occur, err will be non-nil. -- 2.48.1