]> Cypherpunks repositories - gostls13.git/commitdiff
text/template: implement comparison of basic types
authorRob Pike <r@golang.org>
Wed, 21 Aug 2013 01:27:27 +0000 (11:27 +1000)
committerRob Pike <r@golang.org>
Wed, 21 Aug 2013 01:27:27 +0000 (11:27 +1000)
Add eq, lt, etc. to allow one to do simple comparisons.
It's basic types only (booleans, integers, unsigned integers,
floats, complex, string) because that's easy, easy to define,
and covers the great majority of useful cases, while leaving
open the possibility of a more sweeping definition later.

{{if eq .X .Y}}X and Y are equal{{else}}X and Y are unequal{{end}}

R=golang-dev, adg
CC=golang-dev
https://golang.org/cl/13091045

src/pkg/text/template/doc.go
src/pkg/text/template/exec_test.go
src/pkg/text/template/funcs.go

index c9121f74d3577728111ebd390913fa1fcb4dd356..b952789d1cd7a310ea86d55af6ac00f59a6dcb0b 100644 (file)
@@ -301,8 +301,32 @@ Predefined global functions are named as follows.
                Returns the escaped value of the textual representation of
                its arguments in a form suitable for embedding in a URL query.
 
-The boolean functions take any zero value to be false and a non-zero value to
-be true.
+The boolean functions take any zero value to be false and a non-zero
+value to be true.
+
+There is also a set of binary comparison operators defined as
+functions:
+
+       eq
+               Returns the boolean truth of arg1 == arg2
+       ne
+               Returns the boolean truth of arg1 != arg2
+       lt
+               Returns the boolean truth of arg1 < arg2
+       le
+               Returns the boolean truth of arg1 <= arg2
+       gt
+               Returns the boolean truth of arg1 > arg2
+       ge
+               Returns the boolean truth of arg1 >= arg2
+
+These functions work on basic types only (or named basic types,
+such as "type Celsius float32"). They implement the Go rules for
+comparison of values, except that size and exact type are ignored,
+so any integer value may be compared with any other integer value,
+any unsigned integer value may be compared with any other unsigned
+integer value, and so on. However, as usual, one may not compare
+an int with a float32 and so on.
 
 Associated templates
 
index 3d110af9cce18a8c60f17acf853a82ca4ca6a109..341c5021738bafdb5ab30c2bc4d6244c6b7b3eb2 100644 (file)
@@ -863,3 +863,110 @@ func TestMessageForExecuteEmpty(t *testing.T) {
                t.Fatal(err)
        }
 }
+
+type cmpTest struct {
+       expr  string
+       truth string
+       ok    bool
+}
+
+var cmpTests = []cmpTest{
+       {"eq true true", "true", true},
+       {"eq true false", "false", true},
+       {"eq 1+2i 1+2i", "true", true},
+       {"eq 1+2i 1+3i", "false", true},
+       {"eq 1.5 1.5", "true", true},
+       {"eq 1.5 2.5", "false", true},
+       {"eq 1 1", "true", true},
+       {"eq 1 2", "false", true},
+       {"eq `xy` `xy`", "true", true},
+       {"eq `xy` `xyz`", "false", true},
+       {"eq .Xuint .Xuint", "true", true},
+       {"eq .Xuint .Yuint", "false", true},
+       {"ne true true", "false", true},
+       {"ne true false", "true", true},
+       {"ne 1+2i 1+2i", "false", true},
+       {"ne 1+2i 1+3i", "true", true},
+       {"ne 1.5 1.5", "false", true},
+       {"ne 1.5 2.5", "true", true},
+       {"ne 1 1", "false", true},
+       {"ne 1 2", "true", true},
+       {"ne `xy` `xy`", "false", true},
+       {"ne `xy` `xyz`", "true", true},
+       {"ne .Xuint .Xuint", "false", true},
+       {"ne .Xuint .Yuint", "true", true},
+       {"lt 1.5 1.5", "false", true},
+       {"lt 1.5 2.5", "true", true},
+       {"lt 1 1", "false", true},
+       {"lt 1 2", "true", true},
+       {"lt `xy` `xy`", "false", true},
+       {"lt `xy` `xyz`", "true", true},
+       {"lt .Xuint .Xuint", "false", true},
+       {"lt .Xuint .Yuint", "true", true},
+       {"le 1.5 1.5", "true", true},
+       {"le 1.5 2.5", "true", true},
+       {"le 2.5 1.5", "false", true},
+       {"le 1 1", "true", true},
+       {"le 1 2", "true", true},
+       {"le 2 1", "false", true},
+       {"le `xy` `xy`", "true", true},
+       {"le `xy` `xyz`", "true", true},
+       {"le `xyz` `xy`", "false", true},
+       {"le .Xuint .Xuint", "true", true},
+       {"le .Xuint .Yuint", "true", true},
+       {"le .Yuint .Xuint", "false", true},
+       {"gt 1.5 1.5", "false", true},
+       {"gt 1.5 2.5", "false", true},
+       {"gt 1 1", "false", true},
+       {"gt 2 1", "true", true},
+       {"gt 1 2", "false", true},
+       {"gt `xy` `xy`", "false", true},
+       {"gt `xy` `xyz`", "false", true},
+       {"gt .Xuint .Xuint", "false", true},
+       {"gt .Xuint .Yuint", "false", true},
+       {"gt .Yuint .Xuint", "true", true},
+       {"ge 1.5 1.5", "true", true},
+       {"ge 1.5 2.5", "false", true},
+       {"ge 2.5 1.5", "true", true},
+       {"ge 1 1", "true", true},
+       {"ge 1 2", "false", true},
+       {"ge 2 1", "true", true},
+       {"ge `xy` `xy`", "true", true},
+       {"ge `xy` `xyz`", "false", true},
+       {"ge `xyz` `xy`", "true", true},
+       {"ge .Xuint .Xuint", "true", true},
+       {"ge .Xuint .Yuint", "false", true},
+       {"ge .Yuint .Xuint", "true", true},
+       // Errors
+       {"eq 3 4 5", "", false},     // Too many arguments.
+       {"eq `xy` 1", "", false},    // Different types.
+       {"lt true true", "", false}, // Unordered types.
+       {"lt 1+0i 1+0i", "", false}, // Unordered types.
+}
+
+func TestComparison(t *testing.T) {
+       b := new(bytes.Buffer)
+       var cmpStruct = struct {
+               Xuint, Yuint uint
+       }{3, 4}
+       for _, test := range cmpTests {
+               text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr)
+               tmpl, err := New("empty").Parse(text)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               b.Reset()
+               err = tmpl.Execute(b, &cmpStruct)
+               if test.ok && err != nil {
+                       t.Errorf("%s errored incorrectly: %s", test.expr, err)
+                       continue
+               }
+               if !test.ok && err == nil {
+                       t.Errorf("%s did not error")
+                       continue
+               }
+               if b.String() != test.truth {
+                       t.Errorf("%s: want %s; got %s", test.expr, test.truth, b.String())
+               }
+       }
+}
index 643a728cb7a4ac80d4f8e7da4ac6784cb83cd344..9402170bd07d7d9fcb430d11219ce0a51dd79462 100644 (file)
@@ -6,6 +6,7 @@ package template
 
 import (
        "bytes"
+       "errors"
        "fmt"
        "io"
        "net/url"
@@ -35,6 +36,14 @@ var builtins = FuncMap{
        "printf":   fmt.Sprintf,
        "println":  fmt.Sprintln,
        "urlquery": URLQueryEscaper,
+
+       // Comparisons
+       "eq": eq, // ==
+       "ge": ge, // >=
+       "gt": gt, // >
+       "le": le, // <=
+       "lt": lt, // <
+       "ne": ne, // !=
 }
 
 var builtinFuncs = createValueFuncs(builtins)
@@ -248,6 +257,151 @@ func not(arg interface{}) (truth bool) {
        return !truth
 }
 
+// Comparison.
+
+// TODO: Perhaps allow comparison between signed and unsigned integers.
+
+var (
+       errBadComparisonType = errors.New("invalid type for comparison")
+       errBadComparison     = errors.New("incompatible types for comparison")
+)
+
+type kind int
+
+const (
+       invalidKind kind = iota
+       boolKind
+       complexKind
+       intKind
+       floatKind
+       integerKind
+       stringKind
+       uintKind
+)
+
+func basicKind(v reflect.Value) (kind, error) {
+       switch v.Kind() {
+       case reflect.Bool:
+               return boolKind, nil
+       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+               return intKind, nil
+       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+               return uintKind, nil
+       case reflect.Float32, reflect.Float64:
+               return floatKind, nil
+       case reflect.Complex64, reflect.Complex128:
+               return complexKind, nil
+       case reflect.String:
+               return stringKind, nil
+       }
+       return invalidKind, errBadComparisonType
+}
+
+// eq evaluates the comparison a == b.
+func eq(arg1, arg2 interface{}) (bool, error) {
+       v1 := reflect.ValueOf(arg1)
+       k1, err := basicKind(v1)
+       if err != nil {
+               return false, err
+       }
+       v2 := reflect.ValueOf(arg2)
+       k2, err := basicKind(v2)
+       if err != nil {
+               return false, err
+       }
+       if k1 != k2 {
+               return false, errBadComparison
+       }
+       truth := false
+       switch k1 {
+       case boolKind:
+               truth = v1.Bool() == v2.Bool()
+       case complexKind:
+               truth = v1.Complex() == v2.Complex()
+       case floatKind:
+               truth = v1.Float() == v2.Float()
+       case intKind:
+               truth = v1.Int() == v2.Int()
+       case stringKind:
+               truth = v1.String() == v2.String()
+       case uintKind:
+               truth = v1.Uint() == v2.Uint()
+       default:
+               panic("invalid kind")
+       }
+       return truth, nil
+}
+
+// ne evaluates the comparison a != b.
+func ne(arg1, arg2 interface{}) (bool, error) {
+       // != is the inverse of ==.
+       equal, err := eq(arg1, arg2)
+       return !equal, err
+}
+
+// lt evaluates the comparison a < b.
+func lt(arg1, arg2 interface{}) (bool, error) {
+       v1 := reflect.ValueOf(arg1)
+       k1, err := basicKind(v1)
+       if err != nil {
+               return false, err
+       }
+       v2 := reflect.ValueOf(arg2)
+       k2, err := basicKind(v2)
+       if err != nil {
+               return false, err
+       }
+       if k1 != k2 {
+               return false, errBadComparison
+       }
+       truth := false
+       switch k1 {
+       case boolKind, complexKind:
+               return false, errBadComparisonType
+       case floatKind:
+               truth = v1.Float() < v2.Float()
+       case intKind:
+               truth = v1.Int() < v2.Int()
+       case stringKind:
+               truth = v1.String() < v2.String()
+       case uintKind:
+               truth = v1.Uint() < v2.Uint()
+       default:
+               panic("invalid kind")
+       }
+       return truth, nil
+}
+
+// le evaluates the comparison <= b.
+func le(arg1, arg2 interface{}) (bool, error) {
+       // <= is < or ==.
+       lessThan, err := lt(arg1, arg2)
+       if lessThan || err != nil {
+               return lessThan, err
+       }
+       return eq(arg1, arg2)
+}
+
+// gt evaluates the comparison a > b.
+func gt(arg1, arg2 interface{}) (bool, error) {
+       // > is the inverse of <=.
+       lessOrEqual, err := le(arg1, arg2)
+       if err != nil {
+               return false, err
+       }
+       return !lessOrEqual, nil
+}
+
+// ge evaluates the comparison a >= b.
+func ge(arg1, arg2 interface{}) (bool, error) {
+       // >= is the inverse of <.
+       lessThan, err := lt(arg1, arg2)
+       if err != nil {
+               return false, err
+       }
+       return !lessThan, nil
+}
+
 // HTML escaping.
 
 var (