]> Cypherpunks repositories - gostls13.git/commitdiff
exp/template: functions
authorRob Pike <r@golang.org>
Tue, 5 Jul 2011 04:23:51 +0000 (14:23 +1000)
committerRob Pike <r@golang.org>
Tue, 5 Jul 2011 04:23:51 +0000 (14:23 +1000)
Add the ability to attach functions to template and template sets.
Make variadic functions and methods work.
Still to come: static checking of function names during parse.

R=golang-dev, dsymonds
CC=golang-dev
https://golang.org/cl/4643068

src/pkg/exp/template/Makefile
src/pkg/exp/template/exec.go
src/pkg/exp/template/exec_test.go
src/pkg/exp/template/funcs.go [new file with mode: 0644]
src/pkg/exp/template/parse.go
src/pkg/exp/template/set.go

index 50a0bd7234a369431eb28603739133c64e665de9..8550b0d522103fbbe20e7b1057d8c62204484d68 100644 (file)
@@ -7,6 +7,7 @@ include ../../../Make.inc
 TARG=exp/template
 GOFILES=\
        exec.go\
+       funcs.go\
        lex.go\
        parse.go\
        set.go\
index 3eaecd194164d52b7bd57fdcc05a1ac101e672ee..636bc4c334793fcab0b1c853a7440d1da58dd77f 100644 (file)
@@ -174,8 +174,11 @@ func (s *state) evalPipeline(data reflect.Value, pipe []*commandNode) reflect.Va
 
 func (s *state) evalCommand(data reflect.Value, cmd *commandNode, final reflect.Value) reflect.Value {
        firstWord := cmd.args[0]
-       if field, ok := firstWord.(*fieldNode); ok {
-               return s.evalFieldNode(data, field, cmd.args, final)
+       switch n := firstWord.(type) {
+       case *fieldNode:
+               return s.evalFieldNode(data, n, cmd.args, final)
+       case *identifierNode:
+               return s.evalFieldOrCall(data, n.ident, cmd.args, final)
        }
        if len(cmd.args) > 1 || final.IsValid() {
                // TODO: functions
@@ -215,7 +218,7 @@ func (s *state) evalFieldNode(data reflect.Value, field *fieldNode, args []node,
                data = s.evalField(data, field.ident[i])
        }
        // Now it can be a field or method and if a method, gets arguments.
-       return s.evalMethodOrField(data, field.ident[n-1], args, final)
+       return s.evalFieldOrCall(data, field.ident[n-1], args, final)
 }
 
 func (s *state) evalField(data reflect.Value, fieldName string) reflect.Value {
@@ -238,14 +241,18 @@ func (s *state) evalField(data reflect.Value, fieldName string) reflect.Value {
        panic("not reached")
 }
 
-func (s *state) evalMethodOrField(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value {
+func (s *state) evalFieldOrCall(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value {
+       // Is it a function?
+       if function, ok := findFunction(fieldName, s.tmpl, s.set); ok {
+               return s.evalCall(data, function, fieldName, false, args, final)
+       }
        ptr := data
        for data.Kind() == reflect.Ptr {
                ptr, data = data, reflect.Indirect(data)
        }
        // Is it a method? We use the pointer because it has value methods too.
        if method, ok := ptr.Type().MethodByName(fieldName); ok {
-               return s.evalMethod(ptr, method, args, final)
+               return s.evalCall(ptr, method.Func, fieldName, true, args, final)
        }
        if len(args) > 1 || final.IsValid() {
                s.errorf("%s is not a method but has arguments", fieldName)
@@ -263,31 +270,46 @@ var (
        osErrorType = reflect.TypeOf(new(os.Error)).Elem()
 )
 
-func (s *state) evalMethod(v reflect.Value, method reflect.Method, args []node, final reflect.Value) reflect.Value {
-       typ := method.Type
-       fun := method.Func
+func (s *state) evalCall(v, fun reflect.Value, name string, isMethod bool, args []node, final reflect.Value) reflect.Value {
+       typ := fun.Type()
+       if !isMethod && len(args) > 0 { // Args will be nil if it's a niladic call in an argument list
+               args = args[1:] // first arg is name of function; not used in call.
+       }
        numIn := len(args)
        if final.IsValid() {
                numIn++
        }
-       if !typ.IsVariadic() && numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() {
-               s.errorf("wrong number of args for %s: want %d got %d", method.Name, typ.NumIn(), len(args))
+       numFixed := len(args)
+       if typ.IsVariadic() {
+               numFixed = typ.NumIn() - 1 // last arg is the variadic one.
+               if numIn < numFixed {
+                       s.errorf("wrong number of args for %s: want at least %d got %d", name, typ.NumIn()-1, len(args))
+               }
+       } else if numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() {
+               s.errorf("wrong number of args for %s: want %d got %d", name, typ.NumIn(), len(args))
        }
-       // We allow methods with 1 result or 2 results where the second is an os.Error.
-       switch {
-       case typ.NumOut() == 1:
-       case typ.NumOut() == 2 && typ.Out(1) == osErrorType:
-       default:
-               s.errorf("can't handle multiple results from method %q", method.Name)
+       if !goodFunc(typ) {
+               s.errorf("can't handle multiple results from method/function %q", name)
        }
        // Build the arg list.
        argv := make([]reflect.Value, numIn)
        // First arg is the receiver.
-       argv[0] = v
-       // Others must be evaluated.
-       for i := 1; i < len(args); i++ {
+       i := 0
+       if isMethod {
+               argv[0] = v
+               i++
+       }
+       // Others must be evaluated. Fixed args first.
+       for ; i < numFixed; i++ {
                argv[i] = s.evalArg(v, typ.In(i), args[i])
        }
+       // And now the ... args.
+       if typ.IsVariadic() {
+               argType := typ.In(typ.NumIn() - 1).Elem() // Argument is a slice.
+               for ; i < len(args); i++ {
+                       argv[i] = s.evalArg(v, argType, args[i])
+               }
+       }
        // Add final value if necessary.
        if final.IsValid() {
                argv[len(args)] = final
@@ -310,23 +332,27 @@ func (s *state) evalArg(data reflect.Value, typ reflect.Type, n node) reflect.Va
        }
        switch typ.Kind() {
        case reflect.Bool:
-               return s.evalBool(data, typ, n)
+               return s.evalBool(typ, n)
        case reflect.String:
-               return s.evalString(data, typ, n)
+               return s.evalString(typ, n)
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-               return s.evalInteger(data, typ, n)
+               return s.evalInteger(typ, n)
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
-               return s.evalUnsignedInteger(data, typ, n)
+               return s.evalUnsignedInteger(typ, n)
        case reflect.Float32, reflect.Float64:
-               return s.evalFloat(data, typ, n)
+               return s.evalFloat(typ, n)
        case reflect.Complex64, reflect.Complex128:
-               return s.evalComplex(data, typ, n)
+               return s.evalComplex(typ, n)
+       case reflect.Interface:
+               if typ.NumMethod() == 0 {
+                       return s.evalEmptyInterface(data, typ, n)
+               }
        }
-       s.errorf("can't handle node %s for method arg of type %s", n, typ)
+       s.errorf("can't handle %s for arg of type %s", n, typ)
        panic("not reached")
 }
 
-func (s *state) evalBool(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalBool(typ reflect.Type, n node) reflect.Value {
        if n, ok := n.(*boolNode); ok {
                value := reflect.New(typ).Elem()
                value.SetBool(n.true)
@@ -336,7 +362,7 @@ func (s *state) evalBool(v reflect.Value, typ reflect.Type, n node) reflect.Valu
        panic("not reached")
 }
 
-func (s *state) evalString(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalString(typ reflect.Type, n node) reflect.Value {
        if n, ok := n.(*stringNode); ok {
                value := reflect.New(typ).Elem()
                value.SetString(n.text)
@@ -346,7 +372,7 @@ func (s *state) evalString(v reflect.Value, typ reflect.Type, n node) reflect.Va
        panic("not reached")
 }
 
-func (s *state) evalInteger(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalInteger(typ reflect.Type, n node) reflect.Value {
        if n, ok := n.(*numberNode); ok && n.isInt {
                value := reflect.New(typ).Elem()
                value.SetInt(n.int64)
@@ -356,7 +382,7 @@ func (s *state) evalInteger(v reflect.Value, typ reflect.Type, n node) reflect.V
        panic("not reached")
 }
 
-func (s *state) evalUnsignedInteger(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalUnsignedInteger(typ reflect.Type, n node) reflect.Value {
        if n, ok := n.(*numberNode); ok && n.isUint {
                value := reflect.New(typ).Elem()
                value.SetUint(n.uint64)
@@ -366,7 +392,7 @@ func (s *state) evalUnsignedInteger(v reflect.Value, typ reflect.Type, n node) r
        panic("not reached")
 }
 
-func (s *state) evalFloat(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalFloat(typ reflect.Type, n node) reflect.Value {
        if n, ok := n.(*numberNode); ok && n.isFloat {
                value := reflect.New(typ).Elem()
                value.SetFloat(n.float64)
@@ -376,7 +402,7 @@ func (s *state) evalFloat(v reflect.Value, typ reflect.Type, n node) reflect.Val
        panic("not reached")
 }
 
-func (s *state) evalComplex(v reflect.Value, typ reflect.Type, n node) reflect.Value {
+func (s *state) evalComplex(typ reflect.Type, n node) reflect.Value {
        if n, ok := n.(*numberNode); ok && n.isComplex {
                value := reflect.New(typ).Elem()
                value.SetComplex(n.complex128)
@@ -386,6 +412,34 @@ func (s *state) evalComplex(v reflect.Value, typ reflect.Type, n node) reflect.V
        panic("not reached")
 }
 
+func (s *state) evalEmptyInterface(data reflect.Value, typ reflect.Type, n node) reflect.Value {
+       switch n := n.(type) {
+       case *boolNode:
+               return reflect.ValueOf(n.true)
+       case *fieldNode:
+               return s.evalFieldNode(data, n, nil, reflect.Value{})
+       case *identifierNode:
+               return s.evalFieldOrCall(data, n.ident, nil, reflect.Value{})
+       case *numberNode:
+               if n.isComplex {
+                       return reflect.ValueOf(n.complex128)
+               }
+               if n.isInt {
+                       return reflect.ValueOf(n.int64)
+               }
+               if n.isUint {
+                       return reflect.ValueOf(n.uint64)
+               }
+               if n.isFloat {
+                       return reflect.ValueOf(n.float64)
+               }
+       case *stringNode:
+               return reflect.ValueOf(n.text)
+       }
+       s.errorf("can't handle assignment of %s to empty interface argument", n)
+       panic("not reached")
+}
+
 // printValue writes the textual representation of the value to the output of
 // the template.
 func (s *state) printValue(n node, v reflect.Value) {
index 6e4da692e95d31f5e15c3203bb952d590771c745..8784a0b9fdd4afe6664dd980252101f9cc112804 100644 (file)
@@ -140,6 +140,16 @@ var execTests = []execTest{
        {"if slice", "{{if .SI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true},
        {"if emptymap", "{{if .MSIEmpty}}NON-EMPTY{{else}}EMPTY{{end}}", "EMPTY", tVal, true},
        {"if map", "{{if .MSI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true},
+       // Function calls.
+       {"printf", `{{printf "hello, printf"}}`, "hello, printf", tVal, true},
+       {"printf int", `{{printf "%04x" 127}}`, "007f", tVal, true},
+       {"printf float", `{{printf "%g" 3.5}}`, "3.5", tVal, true},
+       {"printf complex", `{{printf "%g" 1+7i}}`, "(1+7i)", tVal, true},
+       {"printf string", `{{printf "%s" "hello"}}`, "hello", tVal, true},
+       {"printf function", `{{printf "%#q" gopher}}`, "`gopher`", tVal, true},
+       {"printf field", `{{printf "%s" .U.V}}`, "v", tVal, true},
+       {"printf method", `{{printf "%s" .Method0}}`, "resultOfMethod0", tVal, true},
+       {"printf lots", `{{printf "%d %s %g %s" 127 "hello" 7-3i .Method0}}`, "127 hello (7-3i) resultOfMethod0", tVal, true},
        // With.
        {"with true", "{{with true}}{{.}}{{end}}", "true", tVal, true},
        {"with false", "{{with false}}{{.}}{{else}}FALSE{{end}}", "FALSE", tVal, true},
@@ -171,10 +181,15 @@ var execTests = []execTest{
        {"error method, no error", "{{.EPERM false}}", "false", tVal, true},
 }
 
+func gopher() string {
+       return "gopher"
+}
+
 func testExecute(execTests []execTest, set *Set, t *testing.T) {
        b := new(bytes.Buffer)
+       funcs := FuncMap{"gopher": gopher}
        for _, test := range execTests {
-               tmpl := New(test.name)
+               tmpl := New(test.name).Funcs(funcs)
                err := tmpl.Parse(test.input)
                if err != nil {
                        t.Errorf("%s: parse error: %s", test.name, err)
diff --git a/src/pkg/exp/template/funcs.go b/src/pkg/exp/template/funcs.go
new file mode 100644 (file)
index 0000000..88f82f3
--- /dev/null
@@ -0,0 +1,63 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package template
+
+import (
+       "fmt"
+       "reflect"
+)
+
+// FuncMap is the type of the map defining the mapping from names to functions.
+// Each function must have either a single return value, or two return values of
+// which the second has type os.Error.
+type FuncMap map[string]interface{}
+
+var funcs = map[string]reflect.Value{
+       "printf": reflect.ValueOf(fmt.Sprintf),
+}
+
+// addFuncs adds to values the functions in funcs, converting them to reflect.Values.
+func addFuncs(values map[string]reflect.Value, funcMap FuncMap) {
+       for name, fn := range funcMap {
+               v := reflect.ValueOf(fn)
+               if v.Kind() != reflect.Func {
+                       panic("value for " + name + " not a function")
+               }
+               if !goodFunc(v.Type()) {
+                       panic(fmt.Errorf("can't handle multiple results from method/function %q", name))
+               }
+               values[name] = v
+       }
+}
+
+// goodFunc checks that the function or method has the right result signature.
+func goodFunc(typ reflect.Type) bool {
+       // We allow functions with 1 result or 2 results where the second is an os.Error.
+       switch {
+       case typ.NumOut() == 1:
+               return true
+       case typ.NumOut() == 2 && typ.Out(1) == osErrorType:
+               return true
+       }
+       return false
+}
+
+// findFunction looks for a function in the template, set, and global map.
+func findFunction(name string, tmpl *Template, set *Set) (reflect.Value, bool) {
+       if tmpl != nil {
+               if fn := tmpl.funcs[name]; fn.IsValid() {
+                       return fn, true
+               }
+       }
+       if set != nil {
+               if fn := set.funcs[name]; fn.IsValid() {
+                       return fn, true
+               }
+       }
+       if fn := funcs[name]; fn.IsValid() {
+               return fn, true
+       }
+       return reflect.Value{}, false
+}
index 74b5f2c0aebfee32f64121f0aa449770214c8122..aaed411d495e0b88cc82dd84227b3ae7bd04d704 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bytes"
        "fmt"
        "os"
+       "reflect"
        "runtime"
        "strconv"
        "strings"
@@ -16,8 +17,9 @@ import (
 
 // Template is the representation of a parsed template.
 type Template struct {
-       name string
-       root *listNode
+       name  string
+       root  *listNode
+       funcs map[string]reflect.Value
        // Parsing.
        lex      *lexer
        tokens   <-chan item
@@ -450,12 +452,23 @@ func (w *withNode) String() string {
 // New allocates a new template with the given name.
 func New(name string) *Template {
        return &Template{
-               name: name,
+               name:  name,
+               funcs: make(map[string]reflect.Value),
        }
 }
 
+// Funcs adds to the template's function map the elements of the
+// argument map.   It panics if a value in the map is not a function
+// with appropriate return type.
+// The return value is the template, so calls can be chained.
+func (t *Template) Funcs(funcMap FuncMap) *Template {
+       addFuncs(t.funcs, funcMap)
+       return t
+}
+
 // errorf formats the error and terminates processing.
 func (t *Template) errorf(format string, args ...interface{}) {
+       t.root = nil
        format = fmt.Sprintf("template: %s:%d: %s", t.name, t.lex.lineNumber(), format)
        panic(fmt.Errorf(format, args...))
 }
@@ -488,7 +501,6 @@ func (t *Template) recover(errp *os.Error) {
                        panic(e)
                }
                t.stopParse()
-               t.root = nil
                *errp = e.(os.Error)
        }
        return
index 13d93d03ca5f95a89425a5e9922666e78099d250..3aaabaad5a98096a41c05b9b122dec46d6471d5e 100644 (file)
@@ -6,6 +6,7 @@ package template
 
 import (
        "os"
+       "reflect"
        "runtime"
        "strconv"
 )
@@ -13,16 +14,27 @@ import (
 // Set holds a set of related templates that can refer to one another by name.
 // A template may be a member of multiple sets.
 type Set struct {
-       tmpl map[string]*Template
+       tmpl  map[string]*Template
+       funcs map[string]reflect.Value
 }
 
 // NewSet allocates a new, empty template set.
 func NewSet() *Set {
        return &Set{
-               tmpl: make(map[string]*Template),
+               tmpl:  make(map[string]*Template),
+               funcs: make(map[string]reflect.Value),
        }
 }
 
+// Funcs adds to the set's function map the elements of the
+// argument map.   It panics if a value in the map is not a function
+// with appropriate return type.
+// The return value is the set, so calls can be chained.
+func (s *Set) Funcs(funcMap FuncMap) *Set {
+       addFuncs(s.funcs, funcMap)
+       return s
+}
+
 // recover is the handler that turns panics into returns from the top
 // level of Parse.
 func (s *Set) recover(errp *os.Error) {