func (s *state) evalCommand(data reflect.Value, cmd *commandNode, final reflect.Value) reflect.Value {
firstWord := cmd.args[0]
- if field, ok := firstWord.(*fieldNode); ok {
- return s.evalFieldNode(data, field, cmd.args, final)
+ switch n := firstWord.(type) {
+ case *fieldNode:
+ return s.evalFieldNode(data, n, cmd.args, final)
+ case *identifierNode:
+ return s.evalFieldOrCall(data, n.ident, cmd.args, final)
}
if len(cmd.args) > 1 || final.IsValid() {
// TODO: functions
data = s.evalField(data, field.ident[i])
}
// Now it can be a field or method and if a method, gets arguments.
- return s.evalMethodOrField(data, field.ident[n-1], args, final)
+ return s.evalFieldOrCall(data, field.ident[n-1], args, final)
}
func (s *state) evalField(data reflect.Value, fieldName string) reflect.Value {
panic("not reached")
}
-func (s *state) evalMethodOrField(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value {
+func (s *state) evalFieldOrCall(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value {
+ // Is it a function?
+ if function, ok := findFunction(fieldName, s.tmpl, s.set); ok {
+ return s.evalCall(data, function, fieldName, false, args, final)
+ }
ptr := data
for data.Kind() == reflect.Ptr {
ptr, data = data, reflect.Indirect(data)
}
// Is it a method? We use the pointer because it has value methods too.
if method, ok := ptr.Type().MethodByName(fieldName); ok {
- return s.evalMethod(ptr, method, args, final)
+ return s.evalCall(ptr, method.Func, fieldName, true, args, final)
}
if len(args) > 1 || final.IsValid() {
s.errorf("%s is not a method but has arguments", fieldName)
osErrorType = reflect.TypeOf(new(os.Error)).Elem()
)
-func (s *state) evalMethod(v reflect.Value, method reflect.Method, args []node, final reflect.Value) reflect.Value {
- typ := method.Type
- fun := method.Func
+func (s *state) evalCall(v, fun reflect.Value, name string, isMethod bool, args []node, final reflect.Value) reflect.Value {
+ typ := fun.Type()
+ if !isMethod && len(args) > 0 { // Args will be nil if it's a niladic call in an argument list
+ args = args[1:] // first arg is name of function; not used in call.
+ }
numIn := len(args)
if final.IsValid() {
numIn++
}
- if !typ.IsVariadic() && numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() {
- s.errorf("wrong number of args for %s: want %d got %d", method.Name, typ.NumIn(), len(args))
+ numFixed := len(args)
+ if typ.IsVariadic() {
+ numFixed = typ.NumIn() - 1 // last arg is the variadic one.
+ if numIn < numFixed {
+ s.errorf("wrong number of args for %s: want at least %d got %d", name, typ.NumIn()-1, len(args))
+ }
+ } else if numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() {
+ s.errorf("wrong number of args for %s: want %d got %d", name, typ.NumIn(), len(args))
}
- // We allow methods with 1 result or 2 results where the second is an os.Error.
- switch {
- case typ.NumOut() == 1:
- case typ.NumOut() == 2 && typ.Out(1) == osErrorType:
- default:
- s.errorf("can't handle multiple results from method %q", method.Name)
+ if !goodFunc(typ) {
+ s.errorf("can't handle multiple results from method/function %q", name)
}
// Build the arg list.
argv := make([]reflect.Value, numIn)
// First arg is the receiver.
- argv[0] = v
- // Others must be evaluated.
- for i := 1; i < len(args); i++ {
+ i := 0
+ if isMethod {
+ argv[0] = v
+ i++
+ }
+ // Others must be evaluated. Fixed args first.
+ for ; i < numFixed; i++ {
argv[i] = s.evalArg(v, typ.In(i), args[i])
}
+ // And now the ... args.
+ if typ.IsVariadic() {
+ argType := typ.In(typ.NumIn() - 1).Elem() // Argument is a slice.
+ for ; i < len(args); i++ {
+ argv[i] = s.evalArg(v, argType, args[i])
+ }
+ }
// Add final value if necessary.
if final.IsValid() {
argv[len(args)] = final
}
switch typ.Kind() {
case reflect.Bool:
- return s.evalBool(data, typ, n)
+ return s.evalBool(typ, n)
case reflect.String:
- return s.evalString(data, typ, n)
+ return s.evalString(typ, n)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return s.evalInteger(data, typ, n)
+ return s.evalInteger(typ, n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- return s.evalUnsignedInteger(data, typ, n)
+ return s.evalUnsignedInteger(typ, n)
case reflect.Float32, reflect.Float64:
- return s.evalFloat(data, typ, n)
+ return s.evalFloat(typ, n)
case reflect.Complex64, reflect.Complex128:
- return s.evalComplex(data, typ, n)
+ return s.evalComplex(typ, n)
+ case reflect.Interface:
+ if typ.NumMethod() == 0 {
+ return s.evalEmptyInterface(data, typ, n)
+ }
}
- s.errorf("can't handle node %s for method arg of type %s", n, typ)
+ s.errorf("can't handle %s for arg of type %s", n, typ)
panic("not reached")
}
-func (s *state) evalBool(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalBool(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*boolNode); ok {
value := reflect.New(typ).Elem()
value.SetBool(n.true)
panic("not reached")
}
-func (s *state) evalString(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalString(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*stringNode); ok {
value := reflect.New(typ).Elem()
value.SetString(n.text)
panic("not reached")
}
-func (s *state) evalInteger(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalInteger(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*numberNode); ok && n.isInt {
value := reflect.New(typ).Elem()
value.SetInt(n.int64)
panic("not reached")
}
-func (s *state) evalUnsignedInteger(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalUnsignedInteger(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*numberNode); ok && n.isUint {
value := reflect.New(typ).Elem()
value.SetUint(n.uint64)
panic("not reached")
}
-func (s *state) evalFloat(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalFloat(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*numberNode); ok && n.isFloat {
value := reflect.New(typ).Elem()
value.SetFloat(n.float64)
panic("not reached")
}
-func (s *state) evalComplex(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalComplex(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*numberNode); ok && n.isComplex {
value := reflect.New(typ).Elem()
value.SetComplex(n.complex128)
panic("not reached")
}
+func (s *state) evalEmptyInterface(data reflect.Value, typ reflect.Type, n node) reflect.Value {
+ switch n := n.(type) {
+ case *boolNode:
+ return reflect.ValueOf(n.true)
+ case *fieldNode:
+ return s.evalFieldNode(data, n, nil, reflect.Value{})
+ case *identifierNode:
+ return s.evalFieldOrCall(data, n.ident, nil, reflect.Value{})
+ case *numberNode:
+ if n.isComplex {
+ return reflect.ValueOf(n.complex128)
+ }
+ if n.isInt {
+ return reflect.ValueOf(n.int64)
+ }
+ if n.isUint {
+ return reflect.ValueOf(n.uint64)
+ }
+ if n.isFloat {
+ return reflect.ValueOf(n.float64)
+ }
+ case *stringNode:
+ return reflect.ValueOf(n.text)
+ }
+ s.errorf("can't handle assignment of %s to empty interface argument", n)
+ panic("not reached")
+}
+
// printValue writes the textual representation of the value to the output of
// the template.
func (s *state) printValue(n node, v reflect.Value) {
{"if slice", "{{if .SI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true},
{"if emptymap", "{{if .MSIEmpty}}NON-EMPTY{{else}}EMPTY{{end}}", "EMPTY", tVal, true},
{"if map", "{{if .MSI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true},
+ // Function calls.
+ {"printf", `{{printf "hello, printf"}}`, "hello, printf", tVal, true},
+ {"printf int", `{{printf "%04x" 127}}`, "007f", tVal, true},
+ {"printf float", `{{printf "%g" 3.5}}`, "3.5", tVal, true},
+ {"printf complex", `{{printf "%g" 1+7i}}`, "(1+7i)", tVal, true},
+ {"printf string", `{{printf "%s" "hello"}}`, "hello", tVal, true},
+ {"printf function", `{{printf "%#q" gopher}}`, "`gopher`", tVal, true},
+ {"printf field", `{{printf "%s" .U.V}}`, "v", tVal, true},
+ {"printf method", `{{printf "%s" .Method0}}`, "resultOfMethod0", tVal, true},
+ {"printf lots", `{{printf "%d %s %g %s" 127 "hello" 7-3i .Method0}}`, "127 hello (7-3i) resultOfMethod0", tVal, true},
// With.
{"with true", "{{with true}}{{.}}{{end}}", "true", tVal, true},
{"with false", "{{with false}}{{.}}{{else}}FALSE{{end}}", "FALSE", tVal, true},
{"error method, no error", "{{.EPERM false}}", "false", tVal, true},
}
+func gopher() string {
+ return "gopher"
+}
+
func testExecute(execTests []execTest, set *Set, t *testing.T) {
b := new(bytes.Buffer)
+ funcs := FuncMap{"gopher": gopher}
for _, test := range execTests {
- tmpl := New(test.name)
+ tmpl := New(test.name).Funcs(funcs)
err := tmpl.Parse(test.input)
if err != nil {
t.Errorf("%s: parse error: %s", test.name, err)
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package template
+
+import (
+ "fmt"
+ "reflect"
+)
+
+// FuncMap is the type of the map defining the mapping from names to functions.
+// Each function must have either a single return value, or two return values of
+// which the second has type os.Error.
+type FuncMap map[string]interface{}
+
+var funcs = map[string]reflect.Value{
+ "printf": reflect.ValueOf(fmt.Sprintf),
+}
+
+// addFuncs adds to values the functions in funcs, converting them to reflect.Values.
+func addFuncs(values map[string]reflect.Value, funcMap FuncMap) {
+ for name, fn := range funcMap {
+ v := reflect.ValueOf(fn)
+ if v.Kind() != reflect.Func {
+ panic("value for " + name + " not a function")
+ }
+ if !goodFunc(v.Type()) {
+ panic(fmt.Errorf("can't handle multiple results from method/function %q", name))
+ }
+ values[name] = v
+ }
+}
+
+// goodFunc checks that the function or method has the right result signature.
+func goodFunc(typ reflect.Type) bool {
+ // We allow functions with 1 result or 2 results where the second is an os.Error.
+ switch {
+ case typ.NumOut() == 1:
+ return true
+ case typ.NumOut() == 2 && typ.Out(1) == osErrorType:
+ return true
+ }
+ return false
+}
+
+// findFunction looks for a function in the template, set, and global map.
+func findFunction(name string, tmpl *Template, set *Set) (reflect.Value, bool) {
+ if tmpl != nil {
+ if fn := tmpl.funcs[name]; fn.IsValid() {
+ return fn, true
+ }
+ }
+ if set != nil {
+ if fn := set.funcs[name]; fn.IsValid() {
+ return fn, true
+ }
+ }
+ if fn := funcs[name]; fn.IsValid() {
+ return fn, true
+ }
+ return reflect.Value{}, false
+}