osopen.go\
httpserver.go\
procattr.go\
+ reflect.go\
+ typecheck.go\
include ../../Make.cmd
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// TODO(rsc): Once there is better support for writing
+// multi-package commands, this should really be in
+// its own package, and then we can drop all the "reflect"
+// prefixes on the global variables and functions.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+ "strings"
+)
+
+var reflectFix = fix{
+ "reflect",
+ reflectFn,
+ `Adapt code to new reflect API.
+
+http://codereview.appspot.com/4281055
+`,
+}
+
+func init() {
+ register(reflectFix)
+}
+
+// The reflect API change dropped the concrete types *reflect.ArrayType etc.
+// Any type assertions prior to method calls can be deleted:
+// x.(*reflect.ArrayType).Len() -> x.Len()
+//
+// Any type checks can be replaced by assignment and check of Kind:
+// x, y := z.(*reflect.ArrayType)
+// ->
+// x := z
+// y := x.Kind() == reflect.Array
+//
+// If z is an ordinary variable name and x is not subsequently assigned to,
+// references to x can be replaced by z and the assignment deleted.
+// We only bother if x and z are the same name.
+// If y is not subsequently assigned to and neither is x, references to
+// y can be replaced by its expression. We only bother when there is
+// just one use or when the use appears in an if clause.
+//
+// Not all type checks result in a single Kind check. The rewrite of the type check for
+// reflect.ArrayOrSliceType checks x.Kind() against reflect.Array and reflect.Slice.
+// The rewrite for *reflect.IntType checks againt Int, Int8, Int16, Int32, Int64.
+// The rewrite for *reflect.UintType adds Uintptr.
+//
+// A type switch turns into an assignment and a switch on Kind:
+// switch x := y.(type) {
+// case reflect.ArrayOrSliceType:
+// ...
+// case *reflect.ChanType:
+// ...
+// case *reflect.IntType:
+// ...
+// }
+// ->
+// switch x := y; x.Kind() {
+// case reflect.Array, reflect.Slice:
+// ...
+// case reflect.Chan:
+// ...
+// case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+// ...
+// }
+//
+// The same simplification applies: we drop x := x if x is not assigned
+// to in the switch cases.
+//
+// Because the type check assignment includes a type assertion in its
+// syntax and the rewrite traversal is bottom up, we must do a pass to
+// rewrite the type check assignments and then a separate pass to
+// rewrite the type assertions.
+//
+// The same process applies to the API changes for reflect.Value.
+//
+// For both cases, but especially Value, the code needs to be aware
+// of the type of a receiver when rewriting a method call. For example,
+// x.(*reflect.ArrayValue).Elem(i) becomes x.Index(i) while
+// x.(*reflect.MapValue).Elem(v) becomes x.MapIndex(v).
+// In general, reflectFn needs to know the type of the receiver expression.
+// In most cases (and in all the cases in the Go source tree), the toy
+// type checker in typecheck.go provides enough information for gofix
+// to make the rewrite. If gofix misses a rewrite, the code that is left over
+// will not compile, so it will be noticed immediately.
+
+func reflectFn(f *ast.File) bool {
+ if !imports(f, "reflect") {
+ return false
+ }
+
+ fixed := false
+
+ // Rewrite names in method calls.
+ // Needs basic type information (see above).
+ typeof := typecheck(reflectTypeConfig, f)
+ walk(f, func(n interface{}) {
+ switch n := n.(type) {
+ case *ast.SelectorExpr:
+ typ := typeof[n.X]
+ if m := reflectRewriteMethod[typ]; m != nil {
+ if replace := m[n.Sel.Name]; replace != "" {
+ n.Sel.Name = replace
+ fixed = true
+ return
+ }
+ }
+
+ // For all reflect Values, replace SetValue with Set.
+ if isReflectValue[typ] && n.Sel.Name == "SetValue" {
+ n.Sel.Name = "Set"
+ fixed = true
+ return
+ }
+
+ // Replace reflect.MakeZero with reflect.Zero.
+ if isPkgDot(n, "reflect", "MakeZero") {
+ n.Sel.Name = "Zero"
+ fixed = true
+ return
+ }
+ }
+ })
+
+ // Replace PtrValue's PointTo(x) with Set(x.Addr()).
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok || len(call.Args) != 1 {
+ return
+ }
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if !ok || sel.Sel.Name != "PointTo" {
+ return
+ }
+ typ := typeof[sel.X]
+ if typ != "*reflect.PtrValue" {
+ return
+ }
+ sel.Sel.Name = "Set"
+ if !isTopName(call.Args[0], "nil") {
+ call.Args[0] = &ast.SelectorExpr{
+ X: call.Args[0],
+ Sel: ast.NewIdent("Addr()"),
+ }
+ }
+ fixed = true
+ })
+
+ // Fix type switches.
+ walk(f, func(n interface{}) {
+ if reflectFixSwitch(n) {
+ fixed = true
+ }
+ })
+
+ // Fix type assertion checks (multiple assignment statements).
+ // Have to work on the statement context (statement list or if statement)
+ // so that we can insert an extra statement occasionally.
+ // Ignoring for and switch because they don't come up in
+ // typical code.
+ walk(f, func(n interface{}) {
+ switch n := n.(type) {
+ case *[]ast.Stmt:
+ // v is the replacement statement list.
+ var v []ast.Stmt
+ insert := func(x ast.Stmt) {
+ v = append(v, x)
+ }
+ for i, x := range *n {
+ // Tentatively append to v; if we rewrite x
+ // we'll have to update the entry, so remember
+ // the index.
+ j := len(v)
+ v = append(v, x)
+ if reflectFixTypecheck(&x, insert, (*n)[i+1:]) {
+ // reflectFixTypecheck may have overwritten x.
+ // Update the entry we appended just before the call.
+ v[j] = x
+ fixed = true
+ }
+ }
+ *n = v
+ case *ast.IfStmt:
+ x := &ast.ExprStmt{n.Cond}
+ if reflectFixTypecheck(&n.Init, nil, []ast.Stmt{x, n.Body, n.Else}) {
+ n.Cond = x.X
+ fixed = true
+ }
+ }
+ })
+
+ // Warn about any typecheck statements that we missed.
+ walk(f, reflectWarnTypecheckStmt)
+
+ // Now that those are gone, fix remaining type assertions.
+ // Delayed because the type checks have
+ // type assertions as part of their syntax.
+ walk(f, func(n interface{}) {
+ if reflectFixAssert(n) {
+ fixed = true
+ }
+ })
+
+ // Now that the type assertions are gone, rewrite remaining
+ // references to specific reflect types to use the general ones.
+ walk(f, func(n interface{}) {
+ ptr, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+ nn := *ptr
+ typ := reflectType(nn)
+ if typ == "" {
+ return
+ }
+ if strings.HasSuffix(typ, "Type") {
+ *ptr = newPkgDot(nn.Pos(), "reflect", "Type")
+ } else {
+ *ptr = newPkgDot(nn.Pos(), "reflect", "Value")
+ }
+ fixed = true
+ })
+
+ // Rewrite v.Set(nil) to v.Set(reflect.MakeZero(v.Type())).
+ walk(f, func(n interface{}) {
+ call, ok := n.(*ast.CallExpr)
+ if !ok || len(call.Args) != 1 || !isTopName(call.Args[0], "nil") {
+ return
+ }
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if !ok || !isReflectValue[typeof[sel.X]] || sel.Sel.Name != "Set" {
+ return
+ }
+ call.Args[0] = &ast.CallExpr{
+ Fun: newPkgDot(call.Args[0].Pos(), "reflect", "Zero"),
+ Args: []ast.Expr{
+ &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: sel.X,
+ Sel: &ast.Ident{Name: "Type"},
+ },
+ },
+ },
+ }
+ fixed = true
+ })
+
+ // Rewrite v != nil to v.IsValid().
+ // Rewrite nil used as reflect.Value (in function argument or return) to reflect.Value{}.
+ walk(f, func(n interface{}) {
+ ptr, ok := n.(*ast.Expr)
+ if !ok {
+ return
+ }
+ if isTopName(*ptr, "nil") && isReflectValue[typeof[*ptr]] {
+ *ptr = ast.NewIdent("reflect.Value{}")
+ fixed = true
+ return
+ }
+ nn, ok := (*ptr).(*ast.BinaryExpr)
+ if !ok || (nn.Op != token.EQL && nn.Op != token.NEQ) || !isTopName(nn.Y, "nil") || !isReflectValue[typeof[nn.X]] {
+ return
+ }
+ var call ast.Expr = &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: nn.X,
+ Sel: &ast.Ident{Name: "IsValid"},
+ },
+ }
+ if nn.Op == token.EQL {
+ call = &ast.UnaryExpr{Op: token.NOT, X: call}
+ }
+ *ptr = call
+ fixed = true
+ })
+
+ return fixed
+}
+
+// reflectFixSwitch rewrites *n (if n is an *ast.Stmt) corresponding
+// to a type switch.
+func reflectFixSwitch(n interface{}) bool {
+ ptr, ok := n.(*ast.Stmt)
+ if !ok {
+ return false
+ }
+ n = *ptr
+
+ ts, ok := n.(*ast.TypeSwitchStmt)
+ if !ok {
+ return false
+ }
+
+ // Are any switch cases referring to reflect types?
+ // (That is, is this an old reflect type switch?)
+ for _, cas := range ts.Body.List {
+ for _, typ := range cas.(*ast.CaseClause).List {
+ if reflectType(typ) != "" {
+ goto haveReflect
+ }
+ }
+ }
+ return false
+
+haveReflect:
+ // Now we know it's an old reflect type switch. Prepare the new version,
+ // but don't replace or edit the original until we're sure of success.
+
+ // Figure out the initializer statement, if any, and the receiver for the Kind call.
+ var init ast.Stmt
+ var rcvr ast.Expr
+
+ init = ts.Init
+ switch n := ts.Assign.(type) {
+ default:
+ warn(ts.Pos(), "unexpected form in type switch")
+ return false
+
+ case *ast.AssignStmt:
+ as := n
+ ta := as.Rhs[0].(*ast.TypeAssertExpr)
+ x := isIdent(as.Lhs[0])
+ z := isIdent(ta.X)
+
+ if isBlank(x) || x != nil && z != nil && x.Name == z.Name && !assignsTo(x, ts.Body.List) {
+ // Can drop the variable creation.
+ rcvr = ta.X
+ } else {
+ // Need to use initialization statement.
+ if init != nil {
+ warn(ts.Pos(), "cannot rewrite reflect type switch with initializing statement")
+ return false
+ }
+ init = &ast.AssignStmt{
+ Lhs: []ast.Expr{as.Lhs[0]},
+ TokPos: as.TokPos,
+ Tok: token.DEFINE,
+ Rhs: []ast.Expr{ta.X},
+ }
+ rcvr = as.Lhs[0]
+ }
+
+ case *ast.ExprStmt:
+ rcvr = n.X.(*ast.TypeAssertExpr).X
+ }
+
+ // Prepare rewritten type switch (see large comment above for form).
+ sw := &ast.SwitchStmt{
+ Switch: ts.Switch,
+ Init: init,
+ Tag: &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: rcvr,
+ Sel: &ast.Ident{
+ NamePos: rcvr.End(),
+ Name: "Kind",
+ Obj: nil,
+ },
+ },
+ Lparen: rcvr.End(),
+ Rparen: rcvr.End(),
+ },
+ Body: &ast.BlockStmt{
+ Lbrace: ts.Body.Lbrace,
+ List: nil, // to be filled in
+ Rbrace: ts.Body.Rbrace,
+ },
+ }
+
+ // Translate cases.
+ for _, tcas := range ts.Body.List {
+ tcas := tcas.(*ast.CaseClause)
+ cas := &ast.CaseClause{
+ Case: tcas.Case,
+ Colon: tcas.Colon,
+ Body: tcas.Body,
+ }
+ for _, t := range tcas.List {
+ if isTopName(t, "nil") {
+ cas.List = append(cas.List, newPkgDot(t.Pos(), "reflect", "Invalid"))
+ continue
+ }
+
+ typ := reflectType(t)
+ if typ == "" {
+ warn(t.Pos(), "cannot rewrite reflect type switch case with non-reflect type %s", gofmt(t))
+ cas.List = append(cas.List, t)
+ continue
+ }
+
+ for _, k := range reflectKind[typ] {
+ cas.List = append(cas.List, newPkgDot(t.Pos(), "reflect", k))
+ }
+ }
+ sw.Body.List = append(sw.Body.List, cas)
+ }
+
+ // Everything worked. Rewrite AST.
+ *ptr = sw
+ return true
+}
+
+// Rewrite x, y = z.(T) into
+// x = z
+// y = x.Kind() == K
+// as described in the long comment above.
+//
+// If insert != nil, it can be called to insert a statement after *ptr in its block.
+// If insert == nil, insertion is not possible.
+// At most one call to insert is allowed.
+//
+// Scope gives the statements for which a declaration
+// in *ptr would be in scope.
+//
+// The result is true of the statement was rewritten.
+//
+func reflectFixTypecheck(ptr *ast.Stmt, insert func(ast.Stmt), scope []ast.Stmt) bool {
+ st := *ptr
+ as, ok := st.(*ast.AssignStmt)
+ if !ok || len(as.Lhs) != 2 || len(as.Rhs) != 1 {
+ return false
+ }
+
+ ta, ok := as.Rhs[0].(*ast.TypeAssertExpr)
+ if !ok {
+ return false
+ }
+ typ := reflectType(ta.Type)
+ if typ == "" {
+ return false
+ }
+
+ // Have x, y := z.(t).
+ x := isIdent(as.Lhs[0])
+ y := isIdent(as.Lhs[1])
+ z := isIdent(ta.X)
+
+ // First step is x := z, unless it's x := x and the resulting x is never reassigned.
+ // rcvr is the x in x.Kind().
+ var rcvr ast.Expr
+ if isBlank(x) ||
+ as.Tok == token.DEFINE && x != nil && z != nil && x.Name == z.Name && !assignsTo(x, scope) {
+ // Can drop the statement.
+ // If we need to insert a statement later, now we have a slot.
+ *ptr = &ast.EmptyStmt{}
+ insert = func(x ast.Stmt) { *ptr = x }
+ rcvr = ta.X
+ } else {
+ *ptr = &ast.AssignStmt{
+ Lhs: []ast.Expr{as.Lhs[0]},
+ TokPos: as.TokPos,
+ Tok: as.Tok,
+ Rhs: []ast.Expr{ta.X},
+ }
+ rcvr = as.Lhs[0]
+ }
+
+ // Prepare x.Kind() == T expression appropriate to t.
+ // If x is not a simple identifier, warn that we might be
+ // reevaluating x.
+ if x == nil {
+ warn(as.Pos(), "rewrite reevaluates expr with possible side effects: %s", gofmt(as.Lhs[0]))
+ }
+ yExpr, yNotExpr := reflectKindEq(rcvr, reflectKind[typ])
+
+ // Second step is y := x.Kind() == T, unless it's only used once
+ // or we have no way to insert that statement.
+ var yStmt *ast.AssignStmt
+ if as.Tok == token.DEFINE && countUses(y, scope) <= 1 || insert == nil {
+ // Can drop the statement and use the expression directly.
+ rewriteUses(y,
+ func(token.Pos) ast.Expr { return yExpr },
+ func(token.Pos) ast.Expr { return yNotExpr },
+ scope)
+ } else {
+ yStmt = &ast.AssignStmt{
+ Lhs: []ast.Expr{as.Lhs[1]},
+ TokPos: as.End(),
+ Tok: as.Tok,
+ Rhs: []ast.Expr{yExpr},
+ }
+ insert(yStmt)
+ }
+ return true
+}
+
+// reflectKindEq returns the expression z.Kind() == kinds[0] || z.Kind() == kinds[1] || ...
+// and its negation.
+// The qualifier "reflect." is inserted before each kinds[i] expression.
+func reflectKindEq(z ast.Expr, kinds []string) (ast.Expr, ast.Expr) {
+ n := len(kinds)
+ if n == 1 {
+ y := &ast.BinaryExpr{
+ X: &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: z,
+ Sel: ast.NewIdent("Kind"),
+ },
+ },
+ Op: token.EQL,
+ Y: newPkgDot(token.NoPos, "reflect", kinds[0]),
+ }
+ ynot := &ast.BinaryExpr{
+ X: &ast.CallExpr{
+ Fun: &ast.SelectorExpr{
+ X: z,
+ Sel: ast.NewIdent("Kind"),
+ },
+ },
+ Op: token.NEQ,
+ Y: newPkgDot(token.NoPos, "reflect", kinds[0]),
+ }
+ return y, ynot
+ }
+
+ x, xnot := reflectKindEq(z, kinds[0:n-1])
+ y, ynot := reflectKindEq(z, kinds[n-1:])
+
+ or := &ast.BinaryExpr{
+ X: x,
+ Op: token.LOR,
+ Y: y,
+ }
+ andnot := &ast.BinaryExpr{
+ X: xnot,
+ Op: token.LAND,
+ Y: ynot,
+ }
+ return or, andnot
+}
+
+// if x represents a known old reflect type/value like *reflect.PtrType or reflect.ArrayOrSliceValue,
+// reflectType returns the string form of that type.
+func reflectType(x ast.Expr) string {
+ ptr, ok := x.(*ast.StarExpr)
+ if ok {
+ x = ptr.X
+ }
+
+ sel, ok := x.(*ast.SelectorExpr)
+ if !ok || !isName(sel.X, "reflect") {
+ return ""
+ }
+
+ var s = "reflect."
+ if ptr != nil {
+ s = "*reflect."
+ }
+ s += sel.Sel.Name
+
+ if reflectKind[s] != nil {
+ return s
+ }
+ return ""
+}
+
+// reflectWarnTypecheckStmt warns about statements
+// of the form x, y = z.(T) for any old reflect type T.
+// The last pass should have gotten them all, and if it didn't,
+// the next pass is going to turn them into x, y = z.
+func reflectWarnTypecheckStmt(n interface{}) {
+ as, ok := n.(*ast.AssignStmt)
+ if !ok || len(as.Lhs) != 2 || len(as.Rhs) != 1 {
+ return
+ }
+ ta, ok := as.Rhs[0].(*ast.TypeAssertExpr)
+ if !ok || reflectType(ta.Type) == "" {
+ return
+ }
+ warn(n.(ast.Node).Pos(), "unfixed reflect type check")
+}
+
+// reflectFixAssert rewrites x.(T) to x for any old reflect type T.
+func reflectFixAssert(n interface{}) bool {
+ ptr, ok := n.(*ast.Expr)
+ if ok {
+ ta, ok := (*ptr).(*ast.TypeAssertExpr)
+ if ok && reflectType(ta.Type) != "" {
+ *ptr = ta.X
+ return true
+ }
+ }
+ return false
+}
+
+// Tables describing the transformations.
+
+// Description of old reflect API for partial type checking.
+// We pretend the Elem method is on Type and Value instead
+// of enumerating all the types it is actually on.
+// Also, we pretend that ArrayType etc embeds Type for the
+// purposes of describing the API. (In fact they embed commonType,
+// which implements Type.)
+var reflectTypeConfig = &TypeConfig{
+ Type: map[string]*Type{
+ "reflect.ArrayOrSliceType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.ArrayOrSliceValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.ArrayType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.ArrayValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.BoolType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.BoolValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.ChanType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.ChanValue": &Type{
+ Method: map[string]string{
+ "Recv": "func() (reflect.Value, bool)",
+ "TryRecv": "func() (reflect.Value, bool)",
+ },
+ Embed: []string{"reflect.Value"},
+ },
+ "reflect.ComplexType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.ComplexValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.FloatType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.FloatValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.FuncType": &Type{
+ Method: map[string]string{
+ "In": "func(int) reflect.Type",
+ "Out": "func(int) reflect.Type",
+ },
+ Embed: []string{"reflect.Type"},
+ },
+ "reflect.FuncValue": &Type{
+ Method: map[string]string{
+ "Call": "func([]reflect.Value) []reflect.Value",
+ },
+ },
+ "reflect.IntType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.IntValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.InterfaceType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.InterfaceValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.MapType": &Type{
+ Method: map[string]string{
+ "Key": "func() reflect.Type",
+ },
+ Embed: []string{"reflect.Type"},
+ },
+ "reflect.MapValue": &Type{
+ Method: map[string]string{
+ "Keys": "func() []reflect.Value",
+ },
+ Embed: []string{"reflect.Value"},
+ },
+ "reflect.Method": &Type{
+ Field: map[string]string{
+ "Type": "*reflect.FuncType",
+ "Func": "*reflect.FuncValue",
+ },
+ },
+ "reflect.PtrType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.PtrValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.SliceType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.SliceValue": &Type{
+ Method: map[string]string{
+ "Slice": "func(int, int) *reflect.SliceValue",
+ },
+ Embed: []string{"reflect.Value"},
+ },
+ "reflect.StringType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.StringValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.StructField": &Type{
+ Field: map[string]string{
+ "Type": "reflect.Type",
+ },
+ },
+ "reflect.StructType": &Type{
+ Method: map[string]string{
+ "Field": "func() reflect.StructField",
+ "FieldByIndex": "func() reflect.StructField",
+ "FieldByName": "func() reflect.StructField,bool",
+ "FieldByNameFunc": "func() reflect.StructField,bool",
+ },
+ Embed: []string{"reflect.Type"},
+ },
+ "reflect.StructValue": &Type{
+ Method: map[string]string{
+ "Field": "func() reflect.Value",
+ "FieldByIndex": "func() reflect.Value",
+ "FieldByName": "func() reflect.Value",
+ "FieldByNameFunc": "func() reflect.Value",
+ },
+ Embed: []string{"reflect.Value"},
+ },
+ "reflect.Type": &Type{
+ Method: map[string]string{
+ "Elem": "func() reflect.Type",
+ "Method": "func() reflect.Method",
+ },
+ },
+ "reflect.UintType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.UintValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.UnsafePointerType": &Type{Embed: []string{"reflect.Type"}},
+ "reflect.UnsafePointerValue": &Type{Embed: []string{"reflect.Value"}},
+ "reflect.Value": &Type{
+ Method: map[string]string{
+ "Addr": "func() *reflect.PtrValue",
+ "Elem": "func() reflect.Value",
+ "Method": "func() *reflect.FuncValue",
+ "SetValue": "func(reflect.Value)",
+ },
+ },
+ },
+ Func: map[string]string{
+ "reflect.Append": "*reflect.SliceValue",
+ "reflect.AppendSlice": "*reflect.SliceValue",
+ "reflect.Indirect": "reflect.Value",
+ "reflect.MakeSlice": "*reflect.SliceValue",
+ "reflect.MakeChan": "*reflect.ChanValue",
+ "reflect.MakeMap": "*reflect.MapValue",
+ "reflect.MakeZero": "reflect.Value",
+ "reflect.NewValue": "reflect.Value",
+ "reflect.PtrTo": "*reflect.PtrType",
+ "reflect.Typeof": "reflect.Type",
+ },
+}
+
+var reflectRewriteMethod = map[string]map[string]string{
+ // The type API didn't change much.
+ "*reflect.ChanType": {"Dir": "ChanDir"},
+ "*reflect.FuncType": {"DotDotDot": "IsVariadic"},
+
+ // The value API has longer names to disambiguate
+ // methods with different signatures.
+ "reflect.ArrayOrSliceValue": { // interface, not pointer
+ "Elem": "Index",
+ },
+ "*reflect.ArrayValue": {
+ "Elem": "Index",
+ },
+ "*reflect.BoolValue": {
+ "Get": "Bool",
+ "Set": "SetBool",
+ },
+ "*reflect.ChanValue": {
+ "Get": "Pointer",
+ },
+ "*reflect.ComplexValue": {
+ "Get": "Complex",
+ "Set": "SetComplex",
+ "Overflow": "OverflowComplex",
+ },
+ "*reflect.FloatValue": {
+ "Get": "Float",
+ "Set": "SetFloat",
+ "Overflow": "OverflowFloat",
+ },
+ "*reflect.FuncValue": {
+ "Get": "Pointer",
+ },
+ "*reflect.IntValue": {
+ "Get": "Int",
+ "Set": "SetInt",
+ "Overflow": "OverflowInt",
+ },
+ "*reflect.InterfaceValue": {
+ "Get": "InterfaceData",
+ },
+ "*reflect.MapValue": {
+ "Elem": "MapIndex",
+ "Get": "Pointer",
+ "Keys": "MapKeys",
+ "SetElem": "SetMapIndex",
+ },
+ "*reflect.PtrValue": {
+ "Get": "Pointer",
+ },
+ "*reflect.SliceValue": {
+ "Elem": "Index",
+ "Get": "Pointer",
+ },
+ "*reflect.StringValue": {
+ "Get": "String",
+ "Set": "SetString",
+ },
+ "*reflect.UintValue": {
+ "Get": "Uint",
+ "Set": "SetUint",
+ "Overflow": "OverflowUint",
+ },
+ "*reflect.UnsafePointerValue": {
+ "Get": "Pointer",
+ "Set": "SetPointer",
+ },
+}
+
+var reflectKind = map[string][]string{
+ "reflect.ArrayOrSliceType": {"Array", "Slice"}, // interface, not pointer
+ "*reflect.ArrayType": {"Array"},
+ "*reflect.BoolType": {"Bool"},
+ "*reflect.ChanType": {"Chan"},
+ "*reflect.ComplexType": {"Complex64", "Complex128"},
+ "*reflect.FloatType": {"Float32", "Float64"},
+ "*reflect.FuncType": {"Func"},
+ "*reflect.IntType": {"Int", "Int8", "Int16", "Int32", "Int64"},
+ "*reflect.InterfaceType": {"Interface"},
+ "*reflect.MapType": {"Map"},
+ "*reflect.PtrType": {"Ptr"},
+ "*reflect.SliceType": {"Slice"},
+ "*reflect.StringType": {"String"},
+ "*reflect.StructType": {"Struct"},
+ "*reflect.UintType": {"Uint", "Uint8", "Uint16", "Uint32", "Uint64", "Uintptr"},
+ "*reflect.UnsafePointerType": {"UnsafePointer"},
+
+ "reflect.ArrayOrSliceValue": {"Array", "Slice"}, // interface, not pointer
+ "*reflect.ArrayValue": {"Array"},
+ "*reflect.BoolValue": {"Bool"},
+ "*reflect.ChanValue": {"Chan"},
+ "*reflect.ComplexValue": {"Complex64", "Complex128"},
+ "*reflect.FloatValue": {"Float32", "Float64"},
+ "*reflect.FuncValue": {"Func"},
+ "*reflect.IntValue": {"Int", "Int8", "Int16", "Int32", "Int64"},
+ "*reflect.InterfaceValue": {"Interface"},
+ "*reflect.MapValue": {"Map"},
+ "*reflect.PtrValue": {"Ptr"},
+ "*reflect.SliceValue": {"Slice"},
+ "*reflect.StringValue": {"String"},
+ "*reflect.StructValue": {"Struct"},
+ "*reflect.UintValue": {"Uint", "Uint8", "Uint16", "Uint32", "Uint64", "Uintptr"},
+ "*reflect.UnsafePointerValue": {"UnsafePointer"},
+}
+
+var isReflectValue = map[string]bool{
+ "reflect.ArrayOrSliceValue": true, // interface, not pointer
+ "*reflect.ArrayValue": true,
+ "*reflect.BoolValue": true,
+ "*reflect.ChanValue": true,
+ "*reflect.ComplexValue": true,
+ "*reflect.FloatValue": true,
+ "*reflect.FuncValue": true,
+ "*reflect.IntValue": true,
+ "*reflect.InterfaceValue": true,
+ "*reflect.MapValue": true,
+ "*reflect.PtrValue": true,
+ "*reflect.SliceValue": true,
+ "*reflect.StringValue": true,
+ "*reflect.StructValue": true,
+ "*reflect.UintValue": true,
+ "*reflect.UnsafePointerValue": true,
+ "reflect.Value": true, // interface, not pointer
+}
--- /dev/null
+package main
+
+import (
+ "io/ioutil"
+ "log"
+ "path/filepath"
+)
+
+func init() {
+ addTestCases(reflectTests())
+}
+
+func reflectTests() []testCase {
+ var tests []testCase
+
+ names, _ := filepath.Glob("testdata/reflect.*.in")
+ for _, in := range names {
+ out := in[:len(in)-len(".in")] + ".out"
+ inb, err := ioutil.ReadFile(in)
+ if err != nil {
+ log.Fatal(err)
+ }
+ outb, err := ioutil.ReadFile(out)
+ if err != nil {
+ log.Fatal(err)
+ }
+ tests = append(tests, testCase{Name: in, In: string(inb), Out: string(outb)})
+ }
+
+ return tests
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The asn1 package implements parsing of DER-encoded ASN.1 data structures,
+// as defined in ITU-T Rec X.690.
+//
+// See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,''
+// http://luca.ntop.org/Teaching/Appunti/asn1.html.
+package asn1
+
+// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc
+// are different encoding formats for those objects. Here, we'll be dealing
+// with DER, the Distinguished Encoding Rules. DER is used in X.509 because
+// it's fast to parse and, unlike BER, has a unique encoding for every object.
+// When calculating hashes over objects, it's important that the resulting
+// bytes be the same at both ends and DER removes this margin of error.
+//
+// ASN.1 is very complex and this package doesn't attempt to implement
+// everything by any means.
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "time"
+)
+
+// A StructuralError suggests that the ASN.1 data is valid, but the Go type
+// which is receiving it doesn't match.
+type StructuralError struct {
+ Msg string
+}
+
+func (e StructuralError) String() string { return "ASN.1 structure error: " + e.Msg }
+
+// A SyntaxError suggests that the ASN.1 data is invalid.
+type SyntaxError struct {
+ Msg string
+}
+
+func (e SyntaxError) String() string { return "ASN.1 syntax error: " + e.Msg }
+
+// We start by dealing with each of the primitive types in turn.
+
+// BOOLEAN
+
+func parseBool(bytes []byte) (ret bool, err os.Error) {
+ if len(bytes) != 1 {
+ err = SyntaxError{"invalid boolean"}
+ return
+ }
+
+ return bytes[0] != 0, nil
+}
+
+// INTEGER
+
+// parseInt64 treats the given bytes as a big-endian, signed integer and
+// returns the result.
+func parseInt64(bytes []byte) (ret int64, err os.Error) {
+ if len(bytes) > 8 {
+ // We'll overflow an int64 in this case.
+ err = StructuralError{"integer too large"}
+ return
+ }
+ for bytesRead := 0; bytesRead < len(bytes); bytesRead++ {
+ ret <<= 8
+ ret |= int64(bytes[bytesRead])
+ }
+
+ // Shift up and down in order to sign extend the result.
+ ret <<= 64 - uint8(len(bytes))*8
+ ret >>= 64 - uint8(len(bytes))*8
+ return
+}
+
+// parseInt treats the given bytes as a big-endian, signed integer and returns
+// the result.
+func parseInt(bytes []byte) (int, os.Error) {
+ ret64, err := parseInt64(bytes)
+ if err != nil {
+ return 0, err
+ }
+ if ret64 != int64(int(ret64)) {
+ return 0, StructuralError{"integer too large"}
+ }
+ return int(ret64), nil
+}
+
+// BIT STRING
+
+// BitString is the structure to use when you want an ASN.1 BIT STRING type. A
+// bit string is padded up to the nearest byte in memory and the number of
+// valid bits is recorded. Padding bits will be zero.
+type BitString struct {
+ Bytes []byte // bits packed into bytes.
+ BitLength int // length in bits.
+}
+
+// At returns the bit at the given index. If the index is out of range it
+// returns false.
+func (b BitString) At(i int) int {
+ if i < 0 || i >= b.BitLength {
+ return 0
+ }
+ x := i / 8
+ y := 7 - uint(i%8)
+ return int(b.Bytes[x]>>y) & 1
+}
+
+// RightAlign returns a slice where the padding bits are at the beginning. The
+// slice may share memory with the BitString.
+func (b BitString) RightAlign() []byte {
+ shift := uint(8 - (b.BitLength % 8))
+ if shift == 8 || len(b.Bytes) == 0 {
+ return b.Bytes
+ }
+
+ a := make([]byte, len(b.Bytes))
+ a[0] = b.Bytes[0] >> shift
+ for i := 1; i < len(b.Bytes); i++ {
+ a[i] = b.Bytes[i-1] << (8 - shift)
+ a[i] |= b.Bytes[i] >> shift
+ }
+
+ return a
+}
+
+// parseBitString parses an ASN.1 bit string from the given byte array and returns it.
+func parseBitString(bytes []byte) (ret BitString, err os.Error) {
+ if len(bytes) == 0 {
+ err = SyntaxError{"zero length BIT STRING"}
+ return
+ }
+ paddingBits := int(bytes[0])
+ if paddingBits > 7 ||
+ len(bytes) == 1 && paddingBits > 0 ||
+ bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 {
+ err = SyntaxError{"invalid padding bits in BIT STRING"}
+ return
+ }
+ ret.BitLength = (len(bytes)-1)*8 - paddingBits
+ ret.Bytes = bytes[1:]
+ return
+}
+
+// OBJECT IDENTIFIER
+
+// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER.
+type ObjectIdentifier []int
+
+// Equal returns true iff oi and other represent the same identifier.
+func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
+ if len(oi) != len(other) {
+ return false
+ }
+ for i := 0; i < len(oi); i++ {
+ if oi[i] != other[i] {
+ return false
+ }
+ }
+
+ return true
+}
+
+// parseObjectIdentifier parses an OBJECT IDENTIFER from the given bytes and
+// returns it. An object identifer is a sequence of variable length integers
+// that are assigned in a hierarachy.
+func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
+ if len(bytes) == 0 {
+ err = SyntaxError{"zero length OBJECT IDENTIFIER"}
+ return
+ }
+
+ // In the worst case, we get two elements from the first byte (which is
+ // encoded differently) and then every varint is a single byte long.
+ s = make([]int, len(bytes)+1)
+
+ // The first byte is 40*value1 + value2:
+ s[0] = int(bytes[0]) / 40
+ s[1] = int(bytes[0]) % 40
+ i := 2
+ for offset := 1; offset < len(bytes); i++ {
+ var v int
+ v, offset, err = parseBase128Int(bytes, offset)
+ if err != nil {
+ return
+ }
+ s[i] = v
+ }
+ s = s[0:i]
+ return
+}
+
+// ENUMERATED
+
+// An Enumerated is represented as a plain int.
+type Enumerated int
+
+
+// FLAG
+
+// A Flag accepts any data and is set to true if present.
+type Flag bool
+
+// parseBase128Int parses a base-128 encoded int from the given offset in the
+// given byte array. It returns the value and the new offset.
+func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) {
+ offset = initOffset
+ for shifted := 0; offset < len(bytes); shifted++ {
+ if shifted > 4 {
+ err = StructuralError{"base 128 integer too large"}
+ return
+ }
+ ret <<= 7
+ b := bytes[offset]
+ ret |= int(b & 0x7f)
+ offset++
+ if b&0x80 == 0 {
+ return
+ }
+ }
+ err = SyntaxError{"truncated base 128 integer"}
+ return
+}
+
+// UTCTime
+
+func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
+ s := string(bytes)
+ ret, err = time.Parse("0601021504Z0700", s)
+ if err == nil {
+ return
+ }
+ ret, err = time.Parse("060102150405Z0700", s)
+ return
+}
+
+// parseGeneralizedTime parses the GeneralizedTime from the given byte array
+// and returns the resulting time.
+func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) {
+ return time.Parse("20060102150405Z0700", string(bytes))
+}
+
+// PrintableString
+
+// parsePrintableString parses a ASN.1 PrintableString from the given byte
+// array and returns it.
+func parsePrintableString(bytes []byte) (ret string, err os.Error) {
+ for _, b := range bytes {
+ if !isPrintable(b) {
+ err = SyntaxError{"PrintableString contains invalid character"}
+ return
+ }
+ }
+ ret = string(bytes)
+ return
+}
+
+// isPrintable returns true iff the given b is in the ASN.1 PrintableString set.
+func isPrintable(b byte) bool {
+ return 'a' <= b && b <= 'z' ||
+ 'A' <= b && b <= 'Z' ||
+ '0' <= b && b <= '9' ||
+ '\'' <= b && b <= ')' ||
+ '+' <= b && b <= '/' ||
+ b == ' ' ||
+ b == ':' ||
+ b == '=' ||
+ b == '?' ||
+ // This is techincally not allowed in a PrintableString.
+ // However, x509 certificates with wildcard strings don't
+ // always use the correct string type so we permit it.
+ b == '*'
+}
+
+// IA5String
+
+// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given
+// byte array and returns it.
+func parseIA5String(bytes []byte) (ret string, err os.Error) {
+ for _, b := range bytes {
+ if b >= 0x80 {
+ err = SyntaxError{"IA5String contains invalid character"}
+ return
+ }
+ }
+ ret = string(bytes)
+ return
+}
+
+// T61String
+
+// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given
+// byte array and returns it.
+func parseT61String(bytes []byte) (ret string, err os.Error) {
+ return string(bytes), nil
+}
+
+// A RawValue represents an undecoded ASN.1 object.
+type RawValue struct {
+ Class, Tag int
+ IsCompound bool
+ Bytes []byte
+ FullBytes []byte // includes the tag and length
+}
+
+// RawContent is used to signal that the undecoded, DER data needs to be
+// preserved for a struct. To use it, the first field of the struct must have
+// this type. It's an error for any of the other fields to have this type.
+type RawContent []byte
+
+// Tagging
+
+// parseTagAndLength parses an ASN.1 tag and length pair from the given offset
+// into a byte array. It returns the parsed data and the new offset. SET and
+// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
+// don't distinguish between ordered and unordered objects in this code.
+func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) {
+ offset = initOffset
+ b := bytes[offset]
+ offset++
+ ret.class = int(b >> 6)
+ ret.isCompound = b&0x20 == 0x20
+ ret.tag = int(b & 0x1f)
+
+ // If the bottom five bits are set, then the tag number is actually base 128
+ // encoded afterwards
+ if ret.tag == 0x1f {
+ ret.tag, offset, err = parseBase128Int(bytes, offset)
+ if err != nil {
+ return
+ }
+ }
+ if offset >= len(bytes) {
+ err = SyntaxError{"truncated tag or length"}
+ return
+ }
+ b = bytes[offset]
+ offset++
+ if b&0x80 == 0 {
+ // The length is encoded in the bottom 7 bits.
+ ret.length = int(b & 0x7f)
+ } else {
+ // Bottom 7 bits give the number of length bytes to follow.
+ numBytes := int(b & 0x7f)
+ // We risk overflowing a signed 32-bit number if we accept more than 3 bytes.
+ if numBytes > 3 {
+ err = StructuralError{"length too large"}
+ return
+ }
+ if numBytes == 0 {
+ err = SyntaxError{"indefinite length found (not DER)"}
+ return
+ }
+ ret.length = 0
+ for i := 0; i < numBytes; i++ {
+ if offset >= len(bytes) {
+ err = SyntaxError{"truncated tag or length"}
+ return
+ }
+ b = bytes[offset]
+ offset++
+ ret.length <<= 8
+ ret.length |= int(b)
+ }
+ }
+
+ return
+}
+
+// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
+// a number of ASN.1 values from the given byte array and returns them as a
+// slice of Go values of the given type.
+func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflect.Type) (ret *reflect.SliceValue, err os.Error) {
+ expectedTag, compoundType, ok := getUniversalType(elemType)
+ if !ok {
+ err = StructuralError{"unknown Go type for slice"}
+ return
+ }
+
+ // First we iterate over the input and count the number of elements,
+ // checking that the types are correct in each case.
+ numElements := 0
+ for offset := 0; offset < len(bytes); {
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ // We pretend that GENERAL STRINGs are PRINTABLE STRINGs so
+ // that a sequence of them can be parsed into a []string.
+ if t.tag == tagGeneralString {
+ t.tag = tagPrintableString
+ }
+ if t.class != classUniversal || t.isCompound != compoundType || t.tag != expectedTag {
+ err = StructuralError{"sequence tag mismatch"}
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"truncated sequence"}
+ return
+ }
+ offset += t.length
+ numElements++
+ }
+ ret = reflect.MakeSlice(sliceType, numElements, numElements)
+ params := fieldParameters{}
+ offset := 0
+ for i := 0; i < numElements; i++ {
+ offset, err = parseField(ret.Elem(i), bytes, offset, params)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+
+var (
+ bitStringType = reflect.Typeof(BitString{})
+ objectIdentifierType = reflect.Typeof(ObjectIdentifier{})
+ enumeratedType = reflect.Typeof(Enumerated(0))
+ flagType = reflect.Typeof(Flag(false))
+ timeType = reflect.Typeof(&time.Time{})
+ rawValueType = reflect.Typeof(RawValue{})
+ rawContentsType = reflect.Typeof(RawContent(nil))
+)
+
+// invalidLength returns true iff offset + length > sliceLength, or if the
+// addition would overflow.
+func invalidLength(offset, length, sliceLength int) bool {
+ return offset+length < offset || offset+length > sliceLength
+}
+
+// parseField is the main parsing function. Given a byte array and an offset
+// into the array, it will try to parse a suitable ASN.1 value out and store it
+// in the given Value.
+func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) {
+ offset = initOffset
+ fieldType := v.Type()
+
+ // If we have run out of data, it may be that there are optional elements at the end.
+ if offset == len(bytes) {
+ if !setDefaultValue(v, params) {
+ err = SyntaxError{"sequence truncated"}
+ }
+ return
+ }
+
+ // Deal with raw values.
+ if fieldType == rawValueType {
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]}
+ offset += t.length
+ v.(*reflect.StructValue).Set(reflect.NewValue(result).(*reflect.StructValue))
+ return
+ }
+
+ // Deal with the ANY type.
+ if ifaceType, ok := fieldType.(*reflect.InterfaceType); ok && ifaceType.NumMethod() == 0 {
+ ifaceValue := v.(*reflect.InterfaceValue)
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ var result interface{}
+ if !t.isCompound && t.class == classUniversal {
+ innerBytes := bytes[offset : offset+t.length]
+ switch t.tag {
+ case tagPrintableString:
+ result, err = parsePrintableString(innerBytes)
+ case tagIA5String:
+ result, err = parseIA5String(innerBytes)
+ case tagT61String:
+ result, err = parseT61String(innerBytes)
+ case tagInteger:
+ result, err = parseInt64(innerBytes)
+ case tagBitString:
+ result, err = parseBitString(innerBytes)
+ case tagOID:
+ result, err = parseObjectIdentifier(innerBytes)
+ case tagUTCTime:
+ result, err = parseUTCTime(innerBytes)
+ case tagOctetString:
+ result = innerBytes
+ default:
+ // If we don't know how to handle the type, we just leave Value as nil.
+ }
+ }
+ offset += t.length
+ if err != nil {
+ return
+ }
+ if result != nil {
+ ifaceValue.Set(reflect.NewValue(result))
+ }
+ return
+ }
+ universalTag, compoundType, ok1 := getUniversalType(fieldType)
+ if !ok1 {
+ err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)}
+ return
+ }
+
+ t, offset, err := parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if params.explicit {
+ expectedClass := classContextSpecific
+ if params.application {
+ expectedClass = classApplication
+ }
+ if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) {
+ if t.length > 0 {
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ } else {
+ if fieldType != flagType {
+ err = StructuralError{"Zero length explicit tag was not an asn1.Flag"}
+ return
+ }
+
+ flagValue := v.(*reflect.BoolValue)
+ flagValue.Set(true)
+ return
+ }
+ } else {
+ // The tags didn't match, it might be an optional element.
+ ok := setDefaultValue(v, params)
+ if ok {
+ offset = initOffset
+ } else {
+ err = StructuralError{"explicitly tagged member didn't match"}
+ }
+ return
+ }
+ }
+
+ // Special case for strings: PrintableString and IA5String both map to
+ // the Go type string. getUniversalType returns the tag for
+ // PrintableString when it sees a string so, if we see an IA5String on
+ // the wire, we change the universal type to match.
+ if universalTag == tagPrintableString && t.tag == tagIA5String {
+ universalTag = tagIA5String
+ }
+ // Likewise for GeneralString
+ if universalTag == tagPrintableString && t.tag == tagGeneralString {
+ universalTag = tagGeneralString
+ }
+
+ // Special case for time: UTCTime and GeneralizedTime both map to the
+ // Go type time.Time.
+ if universalTag == tagUTCTime && t.tag == tagGeneralizedTime {
+ universalTag = tagGeneralizedTime
+ }
+
+ expectedClass := classUniversal
+ expectedTag := universalTag
+
+ if !params.explicit && params.tag != nil {
+ expectedClass = classContextSpecific
+ expectedTag = *params.tag
+ }
+
+ if !params.explicit && params.application && params.tag != nil {
+ expectedClass = classApplication
+ expectedTag = *params.tag
+ }
+
+ // We have unwrapped any explicit tagging at this point.
+ if t.class != expectedClass || t.tag != expectedTag || t.isCompound != compoundType {
+ // Tags don't match. Again, it could be an optional element.
+ ok := setDefaultValue(v, params)
+ if ok {
+ offset = initOffset
+ } else {
+ err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)}
+ }
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ innerBytes := bytes[offset : offset+t.length]
+ offset += t.length
+
+ // We deal with the structures defined in this package first.
+ switch fieldType {
+ case objectIdentifierType:
+ newSlice, err1 := parseObjectIdentifier(innerBytes)
+ sliceValue := v.(*reflect.SliceValue)
+ sliceValue.Set(reflect.MakeSlice(sliceValue.Type().(*reflect.SliceType), len(newSlice), len(newSlice)))
+ if err1 == nil {
+ reflect.Copy(sliceValue, reflect.NewValue(newSlice).(reflect.ArrayOrSliceValue))
+ }
+ err = err1
+ return
+ case bitStringType:
+ structValue := v.(*reflect.StructValue)
+ bs, err1 := parseBitString(innerBytes)
+ if err1 == nil {
+ structValue.Set(reflect.NewValue(bs).(*reflect.StructValue))
+ }
+ err = err1
+ return
+ case timeType:
+ ptrValue := v.(*reflect.PtrValue)
+ var time *time.Time
+ var err1 os.Error
+ if universalTag == tagUTCTime {
+ time, err1 = parseUTCTime(innerBytes)
+ } else {
+ time, err1 = parseGeneralizedTime(innerBytes)
+ }
+ if err1 == nil {
+ ptrValue.Set(reflect.NewValue(time).(*reflect.PtrValue))
+ }
+ err = err1
+ return
+ case enumeratedType:
+ parsedInt, err1 := parseInt(innerBytes)
+ enumValue := v.(*reflect.IntValue)
+ if err1 == nil {
+ enumValue.Set(int64(parsedInt))
+ }
+ err = err1
+ return
+ case flagType:
+ flagValue := v.(*reflect.BoolValue)
+ flagValue.Set(true)
+ return
+ }
+ switch val := v.(type) {
+ case *reflect.BoolValue:
+ parsedBool, err1 := parseBool(innerBytes)
+ if err1 == nil {
+ val.Set(parsedBool)
+ }
+ err = err1
+ return
+ case *reflect.IntValue:
+ switch val.Type().Kind() {
+ case reflect.Int:
+ parsedInt, err1 := parseInt(innerBytes)
+ if err1 == nil {
+ val.Set(int64(parsedInt))
+ }
+ err = err1
+ return
+ case reflect.Int64:
+ parsedInt, err1 := parseInt64(innerBytes)
+ if err1 == nil {
+ val.Set(parsedInt)
+ }
+ err = err1
+ return
+ }
+ case *reflect.StructValue:
+ structType := fieldType.(*reflect.StructType)
+
+ if structType.NumField() > 0 &&
+ structType.Field(0).Type == rawContentsType {
+ bytes := bytes[initOffset:offset]
+ val.Field(0).SetValue(reflect.NewValue(RawContent(bytes)))
+ }
+
+ innerOffset := 0
+ for i := 0; i < structType.NumField(); i++ {
+ field := structType.Field(i)
+ if i == 0 && field.Type == rawContentsType {
+ continue
+ }
+ innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag))
+ if err != nil {
+ return
+ }
+ }
+ // We allow extra bytes at the end of the SEQUENCE because
+ // adding elements to the end has been used in X.509 as the
+ // version numbers have increased.
+ return
+ case *reflect.SliceValue:
+ sliceType := fieldType.(*reflect.SliceType)
+ if sliceType.Elem().Kind() == reflect.Uint8 {
+ val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes)))
+ reflect.Copy(val, reflect.NewValue(innerBytes).(reflect.ArrayOrSliceValue))
+ return
+ }
+ newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem())
+ if err1 == nil {
+ val.Set(newSlice)
+ }
+ err = err1
+ return
+ case *reflect.StringValue:
+ var v string
+ switch universalTag {
+ case tagPrintableString:
+ v, err = parsePrintableString(innerBytes)
+ case tagIA5String:
+ v, err = parseIA5String(innerBytes)
+ case tagT61String:
+ v, err = parseT61String(innerBytes)
+ case tagGeneralString:
+ // GeneralString is specified in ISO-2022/ECMA-35,
+ // A brief review suggests that it includes structures
+ // that allow the encoding to change midstring and
+ // such. We give up and pass it as an 8-bit string.
+ v, err = parseT61String(innerBytes)
+ default:
+ err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)}
+ }
+ if err == nil {
+ val.Set(v)
+ }
+ return
+ }
+ err = StructuralError{"unknown Go type"}
+ return
+}
+
+// setDefaultValue is used to install a default value, from a tag string, into
+// a Value. It is successful is the field was optional, even if a default value
+// wasn't provided or it failed to install it into the Value.
+func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
+ if !params.optional {
+ return
+ }
+ ok = true
+ if params.defaultValue == nil {
+ return
+ }
+ switch val := v.(type) {
+ case *reflect.IntValue:
+ val.Set(*params.defaultValue)
+ }
+ return
+}
+
+// Unmarshal parses the DER-encoded ASN.1 data structure b
+// and uses the reflect package to fill in an arbitrary value pointed at by val.
+// Because Unmarshal uses the reflect package, the structs
+// being written to must use upper case field names.
+//
+// An ASN.1 INTEGER can be written to an int or int64.
+// If the encoded value does not fit in the Go type,
+// Unmarshal returns a parse error.
+//
+// An ASN.1 BIT STRING can be written to a BitString.
+//
+// An ASN.1 OCTET STRING can be written to a []byte.
+//
+// An ASN.1 OBJECT IDENTIFIER can be written to an
+// ObjectIdentifier.
+//
+// An ASN.1 ENUMERATED can be written to an Enumerated.
+//
+// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a *time.Time.
+//
+// An ASN.1 PrintableString or IA5String can be written to a string.
+//
+// Any of the above ASN.1 values can be written to an interface{}.
+// The value stored in the interface has the corresponding Go type.
+// For integers, that type is int64.
+//
+// An ASN.1 SEQUENCE OF x or SET OF x can be written
+// to a slice if an x can be written to the slice's element type.
+//
+// An ASN.1 SEQUENCE or SET can be written to a struct
+// if each of the elements in the sequence can be
+// written to the corresponding element in the struct.
+//
+// The following tags on struct fields have special meaning to Unmarshal:
+//
+// optional marks the field as ASN.1 OPTIONAL
+// [explicit] tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC
+// default:x sets the default value for optional integer fields
+//
+// If the type of the first field of a structure is RawContent then the raw
+// ASN1 contents of the struct will be stored in it.
+//
+// Other ASN.1 types are not supported; if it encounters them,
+// Unmarshal returns a parse error.
+func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) {
+ return UnmarshalWithParams(b, val, "")
+}
+
+// UnmarshalWithParams allows field parameters to be specified for the
+// top-level element. The form of the params is the same as the field tags.
+func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) {
+ v := reflect.NewValue(val).(*reflect.PtrValue).Elem()
+ offset, err := parseField(v, b, 0, parseFieldParameters(params))
+ if err != nil {
+ return nil, err
+ }
+ return b[offset:], nil
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The asn1 package implements parsing of DER-encoded ASN.1 data structures,
+// as defined in ITU-T Rec X.690.
+//
+// See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,''
+// http://luca.ntop.org/Teaching/Appunti/asn1.html.
+package asn1
+
+// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc
+// are different encoding formats for those objects. Here, we'll be dealing
+// with DER, the Distinguished Encoding Rules. DER is used in X.509 because
+// it's fast to parse and, unlike BER, has a unique encoding for every object.
+// When calculating hashes over objects, it's important that the resulting
+// bytes be the same at both ends and DER removes this margin of error.
+//
+// ASN.1 is very complex and this package doesn't attempt to implement
+// everything by any means.
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "time"
+)
+
+// A StructuralError suggests that the ASN.1 data is valid, but the Go type
+// which is receiving it doesn't match.
+type StructuralError struct {
+ Msg string
+}
+
+func (e StructuralError) String() string { return "ASN.1 structure error: " + e.Msg }
+
+// A SyntaxError suggests that the ASN.1 data is invalid.
+type SyntaxError struct {
+ Msg string
+}
+
+func (e SyntaxError) String() string { return "ASN.1 syntax error: " + e.Msg }
+
+// We start by dealing with each of the primitive types in turn.
+
+// BOOLEAN
+
+func parseBool(bytes []byte) (ret bool, err os.Error) {
+ if len(bytes) != 1 {
+ err = SyntaxError{"invalid boolean"}
+ return
+ }
+
+ return bytes[0] != 0, nil
+}
+
+// INTEGER
+
+// parseInt64 treats the given bytes as a big-endian, signed integer and
+// returns the result.
+func parseInt64(bytes []byte) (ret int64, err os.Error) {
+ if len(bytes) > 8 {
+ // We'll overflow an int64 in this case.
+ err = StructuralError{"integer too large"}
+ return
+ }
+ for bytesRead := 0; bytesRead < len(bytes); bytesRead++ {
+ ret <<= 8
+ ret |= int64(bytes[bytesRead])
+ }
+
+ // Shift up and down in order to sign extend the result.
+ ret <<= 64 - uint8(len(bytes))*8
+ ret >>= 64 - uint8(len(bytes))*8
+ return
+}
+
+// parseInt treats the given bytes as a big-endian, signed integer and returns
+// the result.
+func parseInt(bytes []byte) (int, os.Error) {
+ ret64, err := parseInt64(bytes)
+ if err != nil {
+ return 0, err
+ }
+ if ret64 != int64(int(ret64)) {
+ return 0, StructuralError{"integer too large"}
+ }
+ return int(ret64), nil
+}
+
+// BIT STRING
+
+// BitString is the structure to use when you want an ASN.1 BIT STRING type. A
+// bit string is padded up to the nearest byte in memory and the number of
+// valid bits is recorded. Padding bits will be zero.
+type BitString struct {
+ Bytes []byte // bits packed into bytes.
+ BitLength int // length in bits.
+}
+
+// At returns the bit at the given index. If the index is out of range it
+// returns false.
+func (b BitString) At(i int) int {
+ if i < 0 || i >= b.BitLength {
+ return 0
+ }
+ x := i / 8
+ y := 7 - uint(i%8)
+ return int(b.Bytes[x]>>y) & 1
+}
+
+// RightAlign returns a slice where the padding bits are at the beginning. The
+// slice may share memory with the BitString.
+func (b BitString) RightAlign() []byte {
+ shift := uint(8 - (b.BitLength % 8))
+ if shift == 8 || len(b.Bytes) == 0 {
+ return b.Bytes
+ }
+
+ a := make([]byte, len(b.Bytes))
+ a[0] = b.Bytes[0] >> shift
+ for i := 1; i < len(b.Bytes); i++ {
+ a[i] = b.Bytes[i-1] << (8 - shift)
+ a[i] |= b.Bytes[i] >> shift
+ }
+
+ return a
+}
+
+// parseBitString parses an ASN.1 bit string from the given byte array and returns it.
+func parseBitString(bytes []byte) (ret BitString, err os.Error) {
+ if len(bytes) == 0 {
+ err = SyntaxError{"zero length BIT STRING"}
+ return
+ }
+ paddingBits := int(bytes[0])
+ if paddingBits > 7 ||
+ len(bytes) == 1 && paddingBits > 0 ||
+ bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 {
+ err = SyntaxError{"invalid padding bits in BIT STRING"}
+ return
+ }
+ ret.BitLength = (len(bytes)-1)*8 - paddingBits
+ ret.Bytes = bytes[1:]
+ return
+}
+
+// OBJECT IDENTIFIER
+
+// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER.
+type ObjectIdentifier []int
+
+// Equal returns true iff oi and other represent the same identifier.
+func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
+ if len(oi) != len(other) {
+ return false
+ }
+ for i := 0; i < len(oi); i++ {
+ if oi[i] != other[i] {
+ return false
+ }
+ }
+
+ return true
+}
+
+// parseObjectIdentifier parses an OBJECT IDENTIFER from the given bytes and
+// returns it. An object identifer is a sequence of variable length integers
+// that are assigned in a hierarachy.
+func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
+ if len(bytes) == 0 {
+ err = SyntaxError{"zero length OBJECT IDENTIFIER"}
+ return
+ }
+
+ // In the worst case, we get two elements from the first byte (which is
+ // encoded differently) and then every varint is a single byte long.
+ s = make([]int, len(bytes)+1)
+
+ // The first byte is 40*value1 + value2:
+ s[0] = int(bytes[0]) / 40
+ s[1] = int(bytes[0]) % 40
+ i := 2
+ for offset := 1; offset < len(bytes); i++ {
+ var v int
+ v, offset, err = parseBase128Int(bytes, offset)
+ if err != nil {
+ return
+ }
+ s[i] = v
+ }
+ s = s[0:i]
+ return
+}
+
+// ENUMERATED
+
+// An Enumerated is represented as a plain int.
+type Enumerated int
+
+
+// FLAG
+
+// A Flag accepts any data and is set to true if present.
+type Flag bool
+
+// parseBase128Int parses a base-128 encoded int from the given offset in the
+// given byte array. It returns the value and the new offset.
+func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) {
+ offset = initOffset
+ for shifted := 0; offset < len(bytes); shifted++ {
+ if shifted > 4 {
+ err = StructuralError{"base 128 integer too large"}
+ return
+ }
+ ret <<= 7
+ b := bytes[offset]
+ ret |= int(b & 0x7f)
+ offset++
+ if b&0x80 == 0 {
+ return
+ }
+ }
+ err = SyntaxError{"truncated base 128 integer"}
+ return
+}
+
+// UTCTime
+
+func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
+ s := string(bytes)
+ ret, err = time.Parse("0601021504Z0700", s)
+ if err == nil {
+ return
+ }
+ ret, err = time.Parse("060102150405Z0700", s)
+ return
+}
+
+// parseGeneralizedTime parses the GeneralizedTime from the given byte array
+// and returns the resulting time.
+func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) {
+ return time.Parse("20060102150405Z0700", string(bytes))
+}
+
+// PrintableString
+
+// parsePrintableString parses a ASN.1 PrintableString from the given byte
+// array and returns it.
+func parsePrintableString(bytes []byte) (ret string, err os.Error) {
+ for _, b := range bytes {
+ if !isPrintable(b) {
+ err = SyntaxError{"PrintableString contains invalid character"}
+ return
+ }
+ }
+ ret = string(bytes)
+ return
+}
+
+// isPrintable returns true iff the given b is in the ASN.1 PrintableString set.
+func isPrintable(b byte) bool {
+ return 'a' <= b && b <= 'z' ||
+ 'A' <= b && b <= 'Z' ||
+ '0' <= b && b <= '9' ||
+ '\'' <= b && b <= ')' ||
+ '+' <= b && b <= '/' ||
+ b == ' ' ||
+ b == ':' ||
+ b == '=' ||
+ b == '?' ||
+ // This is techincally not allowed in a PrintableString.
+ // However, x509 certificates with wildcard strings don't
+ // always use the correct string type so we permit it.
+ b == '*'
+}
+
+// IA5String
+
+// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given
+// byte array and returns it.
+func parseIA5String(bytes []byte) (ret string, err os.Error) {
+ for _, b := range bytes {
+ if b >= 0x80 {
+ err = SyntaxError{"IA5String contains invalid character"}
+ return
+ }
+ }
+ ret = string(bytes)
+ return
+}
+
+// T61String
+
+// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given
+// byte array and returns it.
+func parseT61String(bytes []byte) (ret string, err os.Error) {
+ return string(bytes), nil
+}
+
+// A RawValue represents an undecoded ASN.1 object.
+type RawValue struct {
+ Class, Tag int
+ IsCompound bool
+ Bytes []byte
+ FullBytes []byte // includes the tag and length
+}
+
+// RawContent is used to signal that the undecoded, DER data needs to be
+// preserved for a struct. To use it, the first field of the struct must have
+// this type. It's an error for any of the other fields to have this type.
+type RawContent []byte
+
+// Tagging
+
+// parseTagAndLength parses an ASN.1 tag and length pair from the given offset
+// into a byte array. It returns the parsed data and the new offset. SET and
+// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
+// don't distinguish between ordered and unordered objects in this code.
+func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) {
+ offset = initOffset
+ b := bytes[offset]
+ offset++
+ ret.class = int(b >> 6)
+ ret.isCompound = b&0x20 == 0x20
+ ret.tag = int(b & 0x1f)
+
+ // If the bottom five bits are set, then the tag number is actually base 128
+ // encoded afterwards
+ if ret.tag == 0x1f {
+ ret.tag, offset, err = parseBase128Int(bytes, offset)
+ if err != nil {
+ return
+ }
+ }
+ if offset >= len(bytes) {
+ err = SyntaxError{"truncated tag or length"}
+ return
+ }
+ b = bytes[offset]
+ offset++
+ if b&0x80 == 0 {
+ // The length is encoded in the bottom 7 bits.
+ ret.length = int(b & 0x7f)
+ } else {
+ // Bottom 7 bits give the number of length bytes to follow.
+ numBytes := int(b & 0x7f)
+ // We risk overflowing a signed 32-bit number if we accept more than 3 bytes.
+ if numBytes > 3 {
+ err = StructuralError{"length too large"}
+ return
+ }
+ if numBytes == 0 {
+ err = SyntaxError{"indefinite length found (not DER)"}
+ return
+ }
+ ret.length = 0
+ for i := 0; i < numBytes; i++ {
+ if offset >= len(bytes) {
+ err = SyntaxError{"truncated tag or length"}
+ return
+ }
+ b = bytes[offset]
+ offset++
+ ret.length <<= 8
+ ret.length |= int(b)
+ }
+ }
+
+ return
+}
+
+// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
+// a number of ASN.1 values from the given byte array and returns them as a
+// slice of Go values of the given type.
+func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) {
+ expectedTag, compoundType, ok := getUniversalType(elemType)
+ if !ok {
+ err = StructuralError{"unknown Go type for slice"}
+ return
+ }
+
+ // First we iterate over the input and count the number of elements,
+ // checking that the types are correct in each case.
+ numElements := 0
+ for offset := 0; offset < len(bytes); {
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ // We pretend that GENERAL STRINGs are PRINTABLE STRINGs so
+ // that a sequence of them can be parsed into a []string.
+ if t.tag == tagGeneralString {
+ t.tag = tagPrintableString
+ }
+ if t.class != classUniversal || t.isCompound != compoundType || t.tag != expectedTag {
+ err = StructuralError{"sequence tag mismatch"}
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"truncated sequence"}
+ return
+ }
+ offset += t.length
+ numElements++
+ }
+ ret = reflect.MakeSlice(sliceType, numElements, numElements)
+ params := fieldParameters{}
+ offset := 0
+ for i := 0; i < numElements; i++ {
+ offset, err = parseField(ret.Index(i), bytes, offset, params)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+
+var (
+ bitStringType = reflect.Typeof(BitString{})
+ objectIdentifierType = reflect.Typeof(ObjectIdentifier{})
+ enumeratedType = reflect.Typeof(Enumerated(0))
+ flagType = reflect.Typeof(Flag(false))
+ timeType = reflect.Typeof(&time.Time{})
+ rawValueType = reflect.Typeof(RawValue{})
+ rawContentsType = reflect.Typeof(RawContent(nil))
+)
+
+// invalidLength returns true iff offset + length > sliceLength, or if the
+// addition would overflow.
+func invalidLength(offset, length, sliceLength int) bool {
+ return offset+length < offset || offset+length > sliceLength
+}
+
+// parseField is the main parsing function. Given a byte array and an offset
+// into the array, it will try to parse a suitable ASN.1 value out and store it
+// in the given Value.
+func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) {
+ offset = initOffset
+ fieldType := v.Type()
+
+ // If we have run out of data, it may be that there are optional elements at the end.
+ if offset == len(bytes) {
+ if !setDefaultValue(v, params) {
+ err = SyntaxError{"sequence truncated"}
+ }
+ return
+ }
+
+ // Deal with raw values.
+ if fieldType == rawValueType {
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]}
+ offset += t.length
+ v.Set(reflect.NewValue(result))
+ return
+ }
+
+ // Deal with the ANY type.
+ if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 {
+ ifaceValue := v
+ var t tagAndLength
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ var result interface{}
+ if !t.isCompound && t.class == classUniversal {
+ innerBytes := bytes[offset : offset+t.length]
+ switch t.tag {
+ case tagPrintableString:
+ result, err = parsePrintableString(innerBytes)
+ case tagIA5String:
+ result, err = parseIA5String(innerBytes)
+ case tagT61String:
+ result, err = parseT61String(innerBytes)
+ case tagInteger:
+ result, err = parseInt64(innerBytes)
+ case tagBitString:
+ result, err = parseBitString(innerBytes)
+ case tagOID:
+ result, err = parseObjectIdentifier(innerBytes)
+ case tagUTCTime:
+ result, err = parseUTCTime(innerBytes)
+ case tagOctetString:
+ result = innerBytes
+ default:
+ // If we don't know how to handle the type, we just leave Value as nil.
+ }
+ }
+ offset += t.length
+ if err != nil {
+ return
+ }
+ if result != nil {
+ ifaceValue.Set(reflect.NewValue(result))
+ }
+ return
+ }
+ universalTag, compoundType, ok1 := getUniversalType(fieldType)
+ if !ok1 {
+ err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)}
+ return
+ }
+
+ t, offset, err := parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ if params.explicit {
+ expectedClass := classContextSpecific
+ if params.application {
+ expectedClass = classApplication
+ }
+ if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) {
+ if t.length > 0 {
+ t, offset, err = parseTagAndLength(bytes, offset)
+ if err != nil {
+ return
+ }
+ } else {
+ if fieldType != flagType {
+ err = StructuralError{"Zero length explicit tag was not an asn1.Flag"}
+ return
+ }
+
+ flagValue := v
+ flagValue.SetBool(true)
+ return
+ }
+ } else {
+ // The tags didn't match, it might be an optional element.
+ ok := setDefaultValue(v, params)
+ if ok {
+ offset = initOffset
+ } else {
+ err = StructuralError{"explicitly tagged member didn't match"}
+ }
+ return
+ }
+ }
+
+ // Special case for strings: PrintableString and IA5String both map to
+ // the Go type string. getUniversalType returns the tag for
+ // PrintableString when it sees a string so, if we see an IA5String on
+ // the wire, we change the universal type to match.
+ if universalTag == tagPrintableString && t.tag == tagIA5String {
+ universalTag = tagIA5String
+ }
+ // Likewise for GeneralString
+ if universalTag == tagPrintableString && t.tag == tagGeneralString {
+ universalTag = tagGeneralString
+ }
+
+ // Special case for time: UTCTime and GeneralizedTime both map to the
+ // Go type time.Time.
+ if universalTag == tagUTCTime && t.tag == tagGeneralizedTime {
+ universalTag = tagGeneralizedTime
+ }
+
+ expectedClass := classUniversal
+ expectedTag := universalTag
+
+ if !params.explicit && params.tag != nil {
+ expectedClass = classContextSpecific
+ expectedTag = *params.tag
+ }
+
+ if !params.explicit && params.application && params.tag != nil {
+ expectedClass = classApplication
+ expectedTag = *params.tag
+ }
+
+ // We have unwrapped any explicit tagging at this point.
+ if t.class != expectedClass || t.tag != expectedTag || t.isCompound != compoundType {
+ // Tags don't match. Again, it could be an optional element.
+ ok := setDefaultValue(v, params)
+ if ok {
+ offset = initOffset
+ } else {
+ err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)}
+ }
+ return
+ }
+ if invalidLength(offset, t.length, len(bytes)) {
+ err = SyntaxError{"data truncated"}
+ return
+ }
+ innerBytes := bytes[offset : offset+t.length]
+ offset += t.length
+
+ // We deal with the structures defined in this package first.
+ switch fieldType {
+ case objectIdentifierType:
+ newSlice, err1 := parseObjectIdentifier(innerBytes)
+ sliceValue := v
+ sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), len(newSlice), len(newSlice)))
+ if err1 == nil {
+ reflect.Copy(sliceValue, reflect.NewValue(newSlice))
+ }
+ err = err1
+ return
+ case bitStringType:
+ structValue := v
+ bs, err1 := parseBitString(innerBytes)
+ if err1 == nil {
+ structValue.Set(reflect.NewValue(bs))
+ }
+ err = err1
+ return
+ case timeType:
+ ptrValue := v
+ var time *time.Time
+ var err1 os.Error
+ if universalTag == tagUTCTime {
+ time, err1 = parseUTCTime(innerBytes)
+ } else {
+ time, err1 = parseGeneralizedTime(innerBytes)
+ }
+ if err1 == nil {
+ ptrValue.Set(reflect.NewValue(time))
+ }
+ err = err1
+ return
+ case enumeratedType:
+ parsedInt, err1 := parseInt(innerBytes)
+ enumValue := v
+ if err1 == nil {
+ enumValue.SetInt(int64(parsedInt))
+ }
+ err = err1
+ return
+ case flagType:
+ flagValue := v
+ flagValue.SetBool(true)
+ return
+ }
+ switch val := v; val.Kind() {
+ case reflect.Bool:
+ parsedBool, err1 := parseBool(innerBytes)
+ if err1 == nil {
+ val.SetBool(parsedBool)
+ }
+ err = err1
+ return
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ switch val.Type().Kind() {
+ case reflect.Int:
+ parsedInt, err1 := parseInt(innerBytes)
+ if err1 == nil {
+ val.SetInt(int64(parsedInt))
+ }
+ err = err1
+ return
+ case reflect.Int64:
+ parsedInt, err1 := parseInt64(innerBytes)
+ if err1 == nil {
+ val.SetInt(parsedInt)
+ }
+ err = err1
+ return
+ }
+ case reflect.Struct:
+ structType := fieldType
+
+ if structType.NumField() > 0 &&
+ structType.Field(0).Type == rawContentsType {
+ bytes := bytes[initOffset:offset]
+ val.Field(0).Set(reflect.NewValue(RawContent(bytes)))
+ }
+
+ innerOffset := 0
+ for i := 0; i < structType.NumField(); i++ {
+ field := structType.Field(i)
+ if i == 0 && field.Type == rawContentsType {
+ continue
+ }
+ innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag))
+ if err != nil {
+ return
+ }
+ }
+ // We allow extra bytes at the end of the SEQUENCE because
+ // adding elements to the end has been used in X.509 as the
+ // version numbers have increased.
+ return
+ case reflect.Slice:
+ sliceType := fieldType
+ if sliceType.Elem().Kind() == reflect.Uint8 {
+ val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes)))
+ reflect.Copy(val, reflect.NewValue(innerBytes))
+ return
+ }
+ newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem())
+ if err1 == nil {
+ val.Set(newSlice)
+ }
+ err = err1
+ return
+ case reflect.String:
+ var v string
+ switch universalTag {
+ case tagPrintableString:
+ v, err = parsePrintableString(innerBytes)
+ case tagIA5String:
+ v, err = parseIA5String(innerBytes)
+ case tagT61String:
+ v, err = parseT61String(innerBytes)
+ case tagGeneralString:
+ // GeneralString is specified in ISO-2022/ECMA-35,
+ // A brief review suggests that it includes structures
+ // that allow the encoding to change midstring and
+ // such. We give up and pass it as an 8-bit string.
+ v, err = parseT61String(innerBytes)
+ default:
+ err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)}
+ }
+ if err == nil {
+ val.SetString(v)
+ }
+ return
+ }
+ err = StructuralError{"unknown Go type"}
+ return
+}
+
+// setDefaultValue is used to install a default value, from a tag string, into
+// a Value. It is successful is the field was optional, even if a default value
+// wasn't provided or it failed to install it into the Value.
+func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
+ if !params.optional {
+ return
+ }
+ ok = true
+ if params.defaultValue == nil {
+ return
+ }
+ switch val := v; val.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ val.SetInt(*params.defaultValue)
+ }
+ return
+}
+
+// Unmarshal parses the DER-encoded ASN.1 data structure b
+// and uses the reflect package to fill in an arbitrary value pointed at by val.
+// Because Unmarshal uses the reflect package, the structs
+// being written to must use upper case field names.
+//
+// An ASN.1 INTEGER can be written to an int or int64.
+// If the encoded value does not fit in the Go type,
+// Unmarshal returns a parse error.
+//
+// An ASN.1 BIT STRING can be written to a BitString.
+//
+// An ASN.1 OCTET STRING can be written to a []byte.
+//
+// An ASN.1 OBJECT IDENTIFIER can be written to an
+// ObjectIdentifier.
+//
+// An ASN.1 ENUMERATED can be written to an Enumerated.
+//
+// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a *time.Time.
+//
+// An ASN.1 PrintableString or IA5String can be written to a string.
+//
+// Any of the above ASN.1 values can be written to an interface{}.
+// The value stored in the interface has the corresponding Go type.
+// For integers, that type is int64.
+//
+// An ASN.1 SEQUENCE OF x or SET OF x can be written
+// to a slice if an x can be written to the slice's element type.
+//
+// An ASN.1 SEQUENCE or SET can be written to a struct
+// if each of the elements in the sequence can be
+// written to the corresponding element in the struct.
+//
+// The following tags on struct fields have special meaning to Unmarshal:
+//
+// optional marks the field as ASN.1 OPTIONAL
+// [explicit] tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC
+// default:x sets the default value for optional integer fields
+//
+// If the type of the first field of a structure is RawContent then the raw
+// ASN1 contents of the struct will be stored in it.
+//
+// Other ASN.1 types are not supported; if it encounters them,
+// Unmarshal returns a parse error.
+func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) {
+ return UnmarshalWithParams(b, val, "")
+}
+
+// UnmarshalWithParams allows field parameters to be specified for the
+// top-level element. The form of the params is the same as the field tags.
+func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) {
+ v := reflect.NewValue(val).Elem()
+ offset, err := parseField(v, b, 0, parseFieldParameters(params))
+ if err != nil {
+ return nil, err
+ }
+ return b[offset:], nil
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/* The datafmt package implements syntax-directed, type-driven formatting
+ of arbitrary data structures. Formatting a data structure consists of
+ two phases: first, a parser reads a format specification and builds a
+ "compiled" format. Then, the format can be applied repeatedly to
+ arbitrary values. Applying a format to a value evaluates to a []byte
+ containing the formatted value bytes, or nil.
+
+ A format specification is a set of package declarations and format rules:
+
+ Format = [ Entry { ";" Entry } [ ";" ] ] .
+ Entry = PackageDecl | FormatRule .
+
+ (The syntax of a format specification is presented in the same EBNF
+ notation as used in the Go language specification. The syntax of white
+ space, comments, identifiers, and string literals is the same as in Go.)
+
+ A package declaration binds a package name (such as 'ast') to a
+ package import path (such as '"go/ast"'). Each package used (in
+ a type name, see below) must be declared once before use.
+
+ PackageDecl = PackageName ImportPath .
+ PackageName = identifier .
+ ImportPath = string .
+
+ A format rule binds a rule name to a format expression. A rule name
+ may be a type name or one of the special names 'default' or '/'.
+ A type name may be the name of a predeclared type (for example, 'int',
+ 'float32', etc.), the package-qualified name of a user-defined type
+ (for example, 'ast.MapType'), or an identifier indicating the structure
+ of unnamed composite types ('array', 'chan', 'func', 'interface', 'map',
+ or 'ptr'). Each rule must have a unique name; rules can be declared in
+ any order.
+
+ FormatRule = RuleName "=" Expression .
+ RuleName = TypeName | "default" | "/" .
+ TypeName = [ PackageName "." ] identifier .
+
+ To format a value, the value's type name is used to select the format rule
+ (there is an override mechanism, see below). The format expression of the
+ selected rule specifies how the value is formatted. Each format expression,
+ when applied to a value, evaluates to a byte sequence or nil.
+
+ In its most general form, a format expression is a list of alternatives,
+ each of which is a sequence of operands:
+
+ Expression = [ Sequence ] { "|" [ Sequence ] } .
+ Sequence = Operand { Operand } .
+
+ The formatted result produced by an expression is the result of the first
+ alternative sequence that evaluates to a non-nil result; if there is no
+ such alternative, the expression evaluates to nil. The result produced by
+ an operand sequence is the concatenation of the results of its operands.
+ If any operand in the sequence evaluates to nil, the entire sequence
+ evaluates to nil.
+
+ There are five kinds of operands:
+
+ Operand = Literal | Field | Group | Option | Repetition .
+
+ Literals evaluate to themselves, with two substitutions. First,
+ %-formats expand in the manner of fmt.Printf, with the current value
+ passed as the parameter. Second, the current indentation (see below)
+ is inserted after every newline or form feed character.
+
+ Literal = string .
+
+ This table shows string literals applied to the value 42 and the
+ corresponding formatted result:
+
+ "foo" foo
+ "%x" 2a
+ "x = %d" x = 42
+ "%#x = %d" 0x2a = 42
+
+ A field operand is a field name optionally followed by an alternate
+ rule name. The field name may be an identifier or one of the special
+ names @ or *.
+
+ Field = FieldName [ ":" RuleName ] .
+ FieldName = identifier | "@" | "*" .
+
+ If the field name is an identifier, the current value must be a struct,
+ and there must be a field with that name in the struct. The same lookup
+ rules apply as in the Go language (for instance, the name of an anonymous
+ field is the unqualified type name). The field name denotes the field
+ value in the struct. If the field is not found, formatting is aborted
+ and an error message is returned. (TODO consider changing the semantics
+ such that if a field is not found, it evaluates to nil).
+
+ The special name '@' denotes the current value.
+
+ The meaning of the special name '*' depends on the type of the current
+ value:
+
+ array, slice types array, slice element (inside {} only, see below)
+ interfaces value stored in interface
+ pointers value pointed to by pointer
+
+ (Implementation restriction: channel, function and map types are not
+ supported due to missing reflection support).
+
+ Fields are evaluated as follows: If the field value is nil, or an array
+ or slice element does not exist, the result is nil (see below for details
+ on array/slice elements). If the value is not nil the field value is
+ formatted (recursively) using the rule corresponding to its type name,
+ or the alternate rule name, if given.
+
+ The following example shows a complete format specification for a
+ struct 'myPackage.Point'. Assume the package
+
+ package myPackage // in directory myDir/myPackage
+ type Point struct {
+ name string;
+ x, y int;
+ }
+
+ Applying the format specification
+
+ myPackage "myDir/myPackage";
+ int = "%d";
+ hexInt = "0x%x";
+ string = "---%s---";
+ myPackage.Point = name "{" x ", " y:hexInt "}";
+
+ to the value myPackage.Point{"foo", 3, 15} results in
+
+ ---foo---{3, 0xf}
+
+ Finally, an operand may be a grouped, optional, or repeated expression.
+ A grouped expression ("group") groups a more complex expression (body)
+ so that it can be used in place of a single operand:
+
+ Group = "(" [ Indentation ">>" ] Body ")" .
+ Indentation = Expression .
+ Body = Expression .
+
+ A group body may be prefixed by an indentation expression followed by '>>'.
+ The indentation expression is applied to the current value like any other
+ expression and the result, if not nil, is appended to the current indentation
+ during the evaluation of the body (see also formatting state, below).
+
+ An optional expression ("option") is enclosed in '[]' brackets.
+
+ Option = "[" Body "]" .
+
+ An option evaluates to its body, except that if the body evaluates to nil,
+ the option expression evaluates to an empty []byte. Thus an option's purpose
+ is to protect the expression containing the option from a nil operand.
+
+ A repeated expression ("repetition") is enclosed in '{}' braces.
+
+ Repetition = "{" Body [ "/" Separator ] "}" .
+ Separator = Expression .
+
+ A repeated expression is evaluated as follows: The body is evaluated
+ repeatedly and its results are concatenated until the body evaluates
+ to nil. The result of the repetition is the (possibly empty) concatenation,
+ but it is never nil. An implicit index is supplied for the evaluation of
+ the body: that index is used to address elements of arrays or slices. If
+ the corresponding elements do not exist, the field denoting the element
+ evaluates to nil (which in turn may terminate the repetition).
+
+ The body of a repetition may be followed by a '/' and a "separator"
+ expression. If the separator is present, it is invoked between repetitions
+ of the body.
+
+ The following example shows a complete format specification for formatting
+ a slice of unnamed type. Applying the specification
+
+ int = "%b";
+ array = { * / ", " }; // array is the type name for an unnamed slice
+
+ to the value '[]int{2, 3, 5, 7}' results in
+
+ 10, 11, 101, 111
+
+ Default rule: If a format rule named 'default' is present, it is used for
+ formatting a value if no other rule was found. A common default rule is
+
+ default = "%v"
+
+ to provide default formatting for basic types without having to specify
+ a specific rule for each basic type.
+
+ Global separator rule: If a format rule named '/' is present, it is
+ invoked with the current value between literals. If the separator
+ expression evaluates to nil, it is ignored.
+
+ For instance, a global separator rule may be used to punctuate a sequence
+ of values with commas. The rules:
+
+ default = "%v";
+ / = ", ";
+
+ will format an argument list by printing each one in its default format,
+ separated by a comma and a space.
+*/
+package datafmt
+
+import (
+ "bytes"
+ "fmt"
+ "go/token"
+ "io"
+ "os"
+ "reflect"
+ "runtime"
+)
+
+
+// ----------------------------------------------------------------------------
+// Format representation
+
+// Custom formatters implement the Formatter function type.
+// A formatter is invoked with the current formatting state, the
+// value to format, and the rule name under which the formatter
+// was installed (the same formatter function may be installed
+// under different names). The formatter may access the current state
+// to guide formatting and use State.Write to append to the state's
+// output.
+//
+// A formatter must return a boolean value indicating if it evaluated
+// to a non-nil value (true), or a nil value (false).
+//
+type Formatter func(state *State, value interface{}, ruleName string) bool
+
+
+// A FormatterMap is a set of custom formatters.
+// It maps a rule name to a formatter function.
+//
+type FormatterMap map[string]Formatter
+
+
+// A parsed format expression is built from the following nodes.
+//
+type (
+ expr interface{}
+
+ alternatives []expr // x | y | z
+
+ sequence []expr // x y z
+
+ literal [][]byte // a list of string segments, possibly starting with '%'
+
+ field struct {
+ fieldName string // including "@", "*"
+ ruleName string // "" if no rule name specified
+ }
+
+ group struct {
+ indent, body expr // (indent >> body)
+ }
+
+ option struct {
+ body expr // [body]
+ }
+
+ repetition struct {
+ body, separator expr // {body / separator}
+ }
+
+ custom struct {
+ ruleName string
+ fun Formatter
+ }
+)
+
+
+// A Format is the result of parsing a format specification.
+// The format may be applied repeatedly to format values.
+//
+type Format map[string]expr
+
+
+// ----------------------------------------------------------------------------
+// Formatting
+
+// An application-specific environment may be provided to Format.Apply;
+// the environment is available inside custom formatters via State.Env().
+// Environments must implement copying; the Copy method must return an
+// complete copy of the receiver. This is necessary so that the formatter
+// can save and restore an environment (in case of an absent expression).
+//
+// If the Environment doesn't change during formatting (this is under
+// control of the custom formatters), the Copy function can simply return
+// the receiver, and thus can be very light-weight.
+//
+type Environment interface {
+ Copy() Environment
+}
+
+
+// State represents the current formatting state.
+// It is provided as argument to custom formatters.
+//
+type State struct {
+ fmt Format // format in use
+ env Environment // user-supplied environment
+ errors chan os.Error // not chan *Error (errors <- nil would be wrong!)
+ hasOutput bool // true after the first literal has been written
+ indent bytes.Buffer // current indentation
+ output bytes.Buffer // format output
+ linePos token.Position // position of line beginning (Column == 0)
+ default_ expr // possibly nil
+ separator expr // possibly nil
+}
+
+
+func newState(fmt Format, env Environment, errors chan os.Error) *State {
+ s := new(State)
+ s.fmt = fmt
+ s.env = env
+ s.errors = errors
+ s.linePos = token.Position{Line: 1}
+
+ // if we have a default rule, cache it's expression for fast access
+ if x, found := fmt["default"]; found {
+ s.default_ = x
+ }
+
+ // if we have a global separator rule, cache it's expression for fast access
+ if x, found := fmt["/"]; found {
+ s.separator = x
+ }
+
+ return s
+}
+
+
+// Env returns the environment passed to Format.Apply.
+func (s *State) Env() interface{} { return s.env }
+
+
+// LinePos returns the position of the current line beginning
+// in the state's output buffer. Line numbers start at 1.
+//
+func (s *State) LinePos() token.Position { return s.linePos }
+
+
+// Pos returns the position of the next byte to be written to the
+// output buffer. Line numbers start at 1.
+//
+func (s *State) Pos() token.Position {
+ offs := s.output.Len()
+ return token.Position{Line: s.linePos.Line, Column: offs - s.linePos.Offset, Offset: offs}
+}
+
+
+// Write writes data to the output buffer, inserting the indentation
+// string after each newline or form feed character. It cannot return an error.
+//
+func (s *State) Write(data []byte) (int, os.Error) {
+ n := 0
+ i0 := 0
+ for i, ch := range data {
+ if ch == '\n' || ch == '\f' {
+ // write text segment and indentation
+ n1, _ := s.output.Write(data[i0 : i+1])
+ n2, _ := s.output.Write(s.indent.Bytes())
+ n += n1 + n2
+ i0 = i + 1
+ s.linePos.Offset = s.output.Len()
+ s.linePos.Line++
+ }
+ }
+ n3, _ := s.output.Write(data[i0:])
+ return n + n3, nil
+}
+
+
+type checkpoint struct {
+ env Environment
+ hasOutput bool
+ outputLen int
+ linePos token.Position
+}
+
+
+func (s *State) save() checkpoint {
+ saved := checkpoint{nil, s.hasOutput, s.output.Len(), s.linePos}
+ if s.env != nil {
+ saved.env = s.env.Copy()
+ }
+ return saved
+}
+
+
+func (s *State) restore(m checkpoint) {
+ s.env = m.env
+ s.output.Truncate(m.outputLen)
+}
+
+
+func (s *State) error(msg string) {
+ s.errors <- os.NewError(msg)
+ runtime.Goexit()
+}
+
+
+// TODO At the moment, unnamed types are simply mapped to the default
+// names below. For instance, all unnamed arrays are mapped to
+// 'array' which is not really sufficient. Eventually one may want
+// to be able to specify rules for say an unnamed slice of T.
+//
+
+func typename(typ reflect.Type) string {
+ switch typ.(type) {
+ case *reflect.ArrayType:
+ return "array"
+ case *reflect.SliceType:
+ return "array"
+ case *reflect.ChanType:
+ return "chan"
+ case *reflect.FuncType:
+ return "func"
+ case *reflect.InterfaceType:
+ return "interface"
+ case *reflect.MapType:
+ return "map"
+ case *reflect.PtrType:
+ return "ptr"
+ }
+ return typ.String()
+}
+
+func (s *State) getFormat(name string) expr {
+ if fexpr, found := s.fmt[name]; found {
+ return fexpr
+ }
+
+ if s.default_ != nil {
+ return s.default_
+ }
+
+ s.error(fmt.Sprintf("no format rule for type: '%s'", name))
+ return nil
+}
+
+
+// eval applies a format expression fexpr to a value. If the expression
+// evaluates internally to a non-nil []byte, that slice is appended to
+// the state's output buffer and eval returns true. Otherwise, eval
+// returns false and the state remains unchanged.
+//
+func (s *State) eval(fexpr expr, value reflect.Value, index int) bool {
+ // an empty format expression always evaluates
+ // to a non-nil (but empty) []byte
+ if fexpr == nil {
+ return true
+ }
+
+ switch t := fexpr.(type) {
+ case alternatives:
+ // append the result of the first alternative that evaluates to
+ // a non-nil []byte to the state's output
+ mark := s.save()
+ for _, x := range t {
+ if s.eval(x, value, index) {
+ return true
+ }
+ s.restore(mark)
+ }
+ return false
+
+ case sequence:
+ // append the result of all operands to the state's output
+ // unless a nil result is encountered
+ mark := s.save()
+ for _, x := range t {
+ if !s.eval(x, value, index) {
+ s.restore(mark)
+ return false
+ }
+ }
+ return true
+
+ case literal:
+ // write separator, if any
+ if s.hasOutput {
+ // not the first literal
+ if s.separator != nil {
+ sep := s.separator // save current separator
+ s.separator = nil // and disable it (avoid recursion)
+ mark := s.save()
+ if !s.eval(sep, value, index) {
+ s.restore(mark)
+ }
+ s.separator = sep // enable it again
+ }
+ }
+ s.hasOutput = true
+ // write literal segments
+ for _, lit := range t {
+ if len(lit) > 1 && lit[0] == '%' {
+ // segment contains a %-format at the beginning
+ if lit[1] == '%' {
+ // "%%" is printed as a single "%"
+ s.Write(lit[1:])
+ } else {
+ // use s instead of s.output to get indentation right
+ fmt.Fprintf(s, string(lit), value.Interface())
+ }
+ } else {
+ // segment contains no %-formats
+ s.Write(lit)
+ }
+ }
+ return true // a literal never evaluates to nil
+
+ case *field:
+ // determine field value
+ switch t.fieldName {
+ case "@":
+ // field value is current value
+
+ case "*":
+ // indirection: operation is type-specific
+ switch v := value.(type) {
+ case *reflect.ArrayValue:
+ if v.Len() <= index {
+ return false
+ }
+ value = v.Elem(index)
+
+ case *reflect.SliceValue:
+ if v.IsNil() || v.Len() <= index {
+ return false
+ }
+ value = v.Elem(index)
+
+ case *reflect.MapValue:
+ s.error("reflection support for maps incomplete")
+
+ case *reflect.PtrValue:
+ if v.IsNil() {
+ return false
+ }
+ value = v.Elem()
+
+ case *reflect.InterfaceValue:
+ if v.IsNil() {
+ return false
+ }
+ value = v.Elem()
+
+ case *reflect.ChanValue:
+ s.error("reflection support for chans incomplete")
+
+ case *reflect.FuncValue:
+ s.error("reflection support for funcs incomplete")
+
+ default:
+ s.error(fmt.Sprintf("error: * does not apply to `%s`", value.Type()))
+ }
+
+ default:
+ // value is value of named field
+ var field reflect.Value
+ if sval, ok := value.(*reflect.StructValue); ok {
+ field = sval.FieldByName(t.fieldName)
+ if field == nil {
+ // TODO consider just returning false in this case
+ s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type()))
+ }
+ }
+ value = field
+ }
+
+ // determine rule
+ ruleName := t.ruleName
+ if ruleName == "" {
+ // no alternate rule name, value type determines rule
+ ruleName = typename(value.Type())
+ }
+ fexpr = s.getFormat(ruleName)
+
+ mark := s.save()
+ if !s.eval(fexpr, value, index) {
+ s.restore(mark)
+ return false
+ }
+ return true
+
+ case *group:
+ // remember current indentation
+ indentLen := s.indent.Len()
+
+ // update current indentation
+ mark := s.save()
+ s.eval(t.indent, value, index)
+ // if the indentation evaluates to nil, the state's output buffer
+ // didn't change - either way it's ok to append the difference to
+ // the current identation
+ s.indent.Write(s.output.Bytes()[mark.outputLen:s.output.Len()])
+ s.restore(mark)
+
+ // format group body
+ mark = s.save()
+ b := true
+ if !s.eval(t.body, value, index) {
+ s.restore(mark)
+ b = false
+ }
+
+ // reset indentation
+ s.indent.Truncate(indentLen)
+ return b
+
+ case *option:
+ // evaluate the body and append the result to the state's output
+ // buffer unless the result is nil
+ mark := s.save()
+ if !s.eval(t.body, value, 0) { // TODO is 0 index correct?
+ s.restore(mark)
+ }
+ return true // an option never evaluates to nil
+
+ case *repetition:
+ // evaluate the body and append the result to the state's output
+ // buffer until a result is nil
+ for i := 0; ; i++ {
+ mark := s.save()
+ // write separator, if any
+ if i > 0 && t.separator != nil {
+ // nil result from separator is ignored
+ mark := s.save()
+ if !s.eval(t.separator, value, i) {
+ s.restore(mark)
+ }
+ }
+ if !s.eval(t.body, value, i) {
+ s.restore(mark)
+ break
+ }
+ }
+ return true // a repetition never evaluates to nil
+
+ case *custom:
+ // invoke the custom formatter to obtain the result
+ mark := s.save()
+ if !t.fun(s, value.Interface(), t.ruleName) {
+ s.restore(mark)
+ return false
+ }
+ return true
+ }
+
+ panic("unreachable")
+ return false
+}
+
+
+// Eval formats each argument according to the format
+// f and returns the resulting []byte and os.Error. If
+// an error occurred, the []byte contains the partially
+// formatted result. An environment env may be passed
+// in which is available in custom formatters through
+// the state parameter.
+//
+func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) {
+ if f == nil {
+ return nil, os.NewError("format is nil")
+ }
+
+ errors := make(chan os.Error)
+ s := newState(f, env, errors)
+
+ go func() {
+ for _, v := range args {
+ fld := reflect.NewValue(v)
+ if fld == nil {
+ errors <- os.NewError("nil argument")
+ return
+ }
+ mark := s.save()
+ if !s.eval(s.getFormat(typename(fld.Type())), fld, 0) { // TODO is 0 index correct?
+ s.restore(mark)
+ }
+ }
+ errors <- nil // no errors
+ }()
+
+ err := <-errors
+ return s.output.Bytes(), err
+}
+
+
+// ----------------------------------------------------------------------------
+// Convenience functions
+
+// Fprint formats each argument according to the format f
+// and writes to w. The result is the total number of bytes
+// written and an os.Error, if any.
+//
+func (f Format) Fprint(w io.Writer, env Environment, args ...interface{}) (int, os.Error) {
+ data, err := f.Eval(env, args...)
+ if err != nil {
+ // TODO should we print partial result in case of error?
+ return 0, err
+ }
+ return w.Write(data)
+}
+
+
+// Print formats each argument according to the format f
+// and writes to standard output. The result is the total
+// number of bytes written and an os.Error, if any.
+//
+func (f Format) Print(args ...interface{}) (int, os.Error) {
+ return f.Fprint(os.Stdout, nil, args...)
+}
+
+
+// Sprint formats each argument according to the format f
+// and returns the resulting string. If an error occurs
+// during formatting, the result string contains the
+// partially formatted result followed by an error message.
+//
+func (f Format) Sprint(args ...interface{}) string {
+ var buf bytes.Buffer
+ _, err := f.Fprint(&buf, nil, args...)
+ if err != nil {
+ var i interface{} = args
+ fmt.Fprintf(&buf, "--- Sprint(%s) failed: %v", fmt.Sprint(i), err)
+ }
+ return buf.String()
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/* The datafmt package implements syntax-directed, type-driven formatting
+ of arbitrary data structures. Formatting a data structure consists of
+ two phases: first, a parser reads a format specification and builds a
+ "compiled" format. Then, the format can be applied repeatedly to
+ arbitrary values. Applying a format to a value evaluates to a []byte
+ containing the formatted value bytes, or nil.
+
+ A format specification is a set of package declarations and format rules:
+
+ Format = [ Entry { ";" Entry } [ ";" ] ] .
+ Entry = PackageDecl | FormatRule .
+
+ (The syntax of a format specification is presented in the same EBNF
+ notation as used in the Go language specification. The syntax of white
+ space, comments, identifiers, and string literals is the same as in Go.)
+
+ A package declaration binds a package name (such as 'ast') to a
+ package import path (such as '"go/ast"'). Each package used (in
+ a type name, see below) must be declared once before use.
+
+ PackageDecl = PackageName ImportPath .
+ PackageName = identifier .
+ ImportPath = string .
+
+ A format rule binds a rule name to a format expression. A rule name
+ may be a type name or one of the special names 'default' or '/'.
+ A type name may be the name of a predeclared type (for example, 'int',
+ 'float32', etc.), the package-qualified name of a user-defined type
+ (for example, 'ast.MapType'), or an identifier indicating the structure
+ of unnamed composite types ('array', 'chan', 'func', 'interface', 'map',
+ or 'ptr'). Each rule must have a unique name; rules can be declared in
+ any order.
+
+ FormatRule = RuleName "=" Expression .
+ RuleName = TypeName | "default" | "/" .
+ TypeName = [ PackageName "." ] identifier .
+
+ To format a value, the value's type name is used to select the format rule
+ (there is an override mechanism, see below). The format expression of the
+ selected rule specifies how the value is formatted. Each format expression,
+ when applied to a value, evaluates to a byte sequence or nil.
+
+ In its most general form, a format expression is a list of alternatives,
+ each of which is a sequence of operands:
+
+ Expression = [ Sequence ] { "|" [ Sequence ] } .
+ Sequence = Operand { Operand } .
+
+ The formatted result produced by an expression is the result of the first
+ alternative sequence that evaluates to a non-nil result; if there is no
+ such alternative, the expression evaluates to nil. The result produced by
+ an operand sequence is the concatenation of the results of its operands.
+ If any operand in the sequence evaluates to nil, the entire sequence
+ evaluates to nil.
+
+ There are five kinds of operands:
+
+ Operand = Literal | Field | Group | Option | Repetition .
+
+ Literals evaluate to themselves, with two substitutions. First,
+ %-formats expand in the manner of fmt.Printf, with the current value
+ passed as the parameter. Second, the current indentation (see below)
+ is inserted after every newline or form feed character.
+
+ Literal = string .
+
+ This table shows string literals applied to the value 42 and the
+ corresponding formatted result:
+
+ "foo" foo
+ "%x" 2a
+ "x = %d" x = 42
+ "%#x = %d" 0x2a = 42
+
+ A field operand is a field name optionally followed by an alternate
+ rule name. The field name may be an identifier or one of the special
+ names @ or *.
+
+ Field = FieldName [ ":" RuleName ] .
+ FieldName = identifier | "@" | "*" .
+
+ If the field name is an identifier, the current value must be a struct,
+ and there must be a field with that name in the struct. The same lookup
+ rules apply as in the Go language (for instance, the name of an anonymous
+ field is the unqualified type name). The field name denotes the field
+ value in the struct. If the field is not found, formatting is aborted
+ and an error message is returned. (TODO consider changing the semantics
+ such that if a field is not found, it evaluates to nil).
+
+ The special name '@' denotes the current value.
+
+ The meaning of the special name '*' depends on the type of the current
+ value:
+
+ array, slice types array, slice element (inside {} only, see below)
+ interfaces value stored in interface
+ pointers value pointed to by pointer
+
+ (Implementation restriction: channel, function and map types are not
+ supported due to missing reflection support).
+
+ Fields are evaluated as follows: If the field value is nil, or an array
+ or slice element does not exist, the result is nil (see below for details
+ on array/slice elements). If the value is not nil the field value is
+ formatted (recursively) using the rule corresponding to its type name,
+ or the alternate rule name, if given.
+
+ The following example shows a complete format specification for a
+ struct 'myPackage.Point'. Assume the package
+
+ package myPackage // in directory myDir/myPackage
+ type Point struct {
+ name string;
+ x, y int;
+ }
+
+ Applying the format specification
+
+ myPackage "myDir/myPackage";
+ int = "%d";
+ hexInt = "0x%x";
+ string = "---%s---";
+ myPackage.Point = name "{" x ", " y:hexInt "}";
+
+ to the value myPackage.Point{"foo", 3, 15} results in
+
+ ---foo---{3, 0xf}
+
+ Finally, an operand may be a grouped, optional, or repeated expression.
+ A grouped expression ("group") groups a more complex expression (body)
+ so that it can be used in place of a single operand:
+
+ Group = "(" [ Indentation ">>" ] Body ")" .
+ Indentation = Expression .
+ Body = Expression .
+
+ A group body may be prefixed by an indentation expression followed by '>>'.
+ The indentation expression is applied to the current value like any other
+ expression and the result, if not nil, is appended to the current indentation
+ during the evaluation of the body (see also formatting state, below).
+
+ An optional expression ("option") is enclosed in '[]' brackets.
+
+ Option = "[" Body "]" .
+
+ An option evaluates to its body, except that if the body evaluates to nil,
+ the option expression evaluates to an empty []byte. Thus an option's purpose
+ is to protect the expression containing the option from a nil operand.
+
+ A repeated expression ("repetition") is enclosed in '{}' braces.
+
+ Repetition = "{" Body [ "/" Separator ] "}" .
+ Separator = Expression .
+
+ A repeated expression is evaluated as follows: The body is evaluated
+ repeatedly and its results are concatenated until the body evaluates
+ to nil. The result of the repetition is the (possibly empty) concatenation,
+ but it is never nil. An implicit index is supplied for the evaluation of
+ the body: that index is used to address elements of arrays or slices. If
+ the corresponding elements do not exist, the field denoting the element
+ evaluates to nil (which in turn may terminate the repetition).
+
+ The body of a repetition may be followed by a '/' and a "separator"
+ expression. If the separator is present, it is invoked between repetitions
+ of the body.
+
+ The following example shows a complete format specification for formatting
+ a slice of unnamed type. Applying the specification
+
+ int = "%b";
+ array = { * / ", " }; // array is the type name for an unnamed slice
+
+ to the value '[]int{2, 3, 5, 7}' results in
+
+ 10, 11, 101, 111
+
+ Default rule: If a format rule named 'default' is present, it is used for
+ formatting a value if no other rule was found. A common default rule is
+
+ default = "%v"
+
+ to provide default formatting for basic types without having to specify
+ a specific rule for each basic type.
+
+ Global separator rule: If a format rule named '/' is present, it is
+ invoked with the current value between literals. If the separator
+ expression evaluates to nil, it is ignored.
+
+ For instance, a global separator rule may be used to punctuate a sequence
+ of values with commas. The rules:
+
+ default = "%v";
+ / = ", ";
+
+ will format an argument list by printing each one in its default format,
+ separated by a comma and a space.
+*/
+package datafmt
+
+import (
+ "bytes"
+ "fmt"
+ "go/token"
+ "io"
+ "os"
+ "reflect"
+ "runtime"
+)
+
+
+// ----------------------------------------------------------------------------
+// Format representation
+
+// Custom formatters implement the Formatter function type.
+// A formatter is invoked with the current formatting state, the
+// value to format, and the rule name under which the formatter
+// was installed (the same formatter function may be installed
+// under different names). The formatter may access the current state
+// to guide formatting and use State.Write to append to the state's
+// output.
+//
+// A formatter must return a boolean value indicating if it evaluated
+// to a non-nil value (true), or a nil value (false).
+//
+type Formatter func(state *State, value interface{}, ruleName string) bool
+
+
+// A FormatterMap is a set of custom formatters.
+// It maps a rule name to a formatter function.
+//
+type FormatterMap map[string]Formatter
+
+
+// A parsed format expression is built from the following nodes.
+//
+type (
+ expr interface{}
+
+ alternatives []expr // x | y | z
+
+ sequence []expr // x y z
+
+ literal [][]byte // a list of string segments, possibly starting with '%'
+
+ field struct {
+ fieldName string // including "@", "*"
+ ruleName string // "" if no rule name specified
+ }
+
+ group struct {
+ indent, body expr // (indent >> body)
+ }
+
+ option struct {
+ body expr // [body]
+ }
+
+ repetition struct {
+ body, separator expr // {body / separator}
+ }
+
+ custom struct {
+ ruleName string
+ fun Formatter
+ }
+)
+
+
+// A Format is the result of parsing a format specification.
+// The format may be applied repeatedly to format values.
+//
+type Format map[string]expr
+
+
+// ----------------------------------------------------------------------------
+// Formatting
+
+// An application-specific environment may be provided to Format.Apply;
+// the environment is available inside custom formatters via State.Env().
+// Environments must implement copying; the Copy method must return an
+// complete copy of the receiver. This is necessary so that the formatter
+// can save and restore an environment (in case of an absent expression).
+//
+// If the Environment doesn't change during formatting (this is under
+// control of the custom formatters), the Copy function can simply return
+// the receiver, and thus can be very light-weight.
+//
+type Environment interface {
+ Copy() Environment
+}
+
+
+// State represents the current formatting state.
+// It is provided as argument to custom formatters.
+//
+type State struct {
+ fmt Format // format in use
+ env Environment // user-supplied environment
+ errors chan os.Error // not chan *Error (errors <- nil would be wrong!)
+ hasOutput bool // true after the first literal has been written
+ indent bytes.Buffer // current indentation
+ output bytes.Buffer // format output
+ linePos token.Position // position of line beginning (Column == 0)
+ default_ expr // possibly nil
+ separator expr // possibly nil
+}
+
+
+func newState(fmt Format, env Environment, errors chan os.Error) *State {
+ s := new(State)
+ s.fmt = fmt
+ s.env = env
+ s.errors = errors
+ s.linePos = token.Position{Line: 1}
+
+ // if we have a default rule, cache it's expression for fast access
+ if x, found := fmt["default"]; found {
+ s.default_ = x
+ }
+
+ // if we have a global separator rule, cache it's expression for fast access
+ if x, found := fmt["/"]; found {
+ s.separator = x
+ }
+
+ return s
+}
+
+
+// Env returns the environment passed to Format.Apply.
+func (s *State) Env() interface{} { return s.env }
+
+
+// LinePos returns the position of the current line beginning
+// in the state's output buffer. Line numbers start at 1.
+//
+func (s *State) LinePos() token.Position { return s.linePos }
+
+
+// Pos returns the position of the next byte to be written to the
+// output buffer. Line numbers start at 1.
+//
+func (s *State) Pos() token.Position {
+ offs := s.output.Len()
+ return token.Position{Line: s.linePos.Line, Column: offs - s.linePos.Offset, Offset: offs}
+}
+
+
+// Write writes data to the output buffer, inserting the indentation
+// string after each newline or form feed character. It cannot return an error.
+//
+func (s *State) Write(data []byte) (int, os.Error) {
+ n := 0
+ i0 := 0
+ for i, ch := range data {
+ if ch == '\n' || ch == '\f' {
+ // write text segment and indentation
+ n1, _ := s.output.Write(data[i0 : i+1])
+ n2, _ := s.output.Write(s.indent.Bytes())
+ n += n1 + n2
+ i0 = i + 1
+ s.linePos.Offset = s.output.Len()
+ s.linePos.Line++
+ }
+ }
+ n3, _ := s.output.Write(data[i0:])
+ return n + n3, nil
+}
+
+
+type checkpoint struct {
+ env Environment
+ hasOutput bool
+ outputLen int
+ linePos token.Position
+}
+
+
+func (s *State) save() checkpoint {
+ saved := checkpoint{nil, s.hasOutput, s.output.Len(), s.linePos}
+ if s.env != nil {
+ saved.env = s.env.Copy()
+ }
+ return saved
+}
+
+
+func (s *State) restore(m checkpoint) {
+ s.env = m.env
+ s.output.Truncate(m.outputLen)
+}
+
+
+func (s *State) error(msg string) {
+ s.errors <- os.NewError(msg)
+ runtime.Goexit()
+}
+
+
+// TODO At the moment, unnamed types are simply mapped to the default
+// names below. For instance, all unnamed arrays are mapped to
+// 'array' which is not really sufficient. Eventually one may want
+// to be able to specify rules for say an unnamed slice of T.
+//
+
+func typename(typ reflect.Type) string {
+ switch typ.Kind() {
+ case reflect.Array:
+ return "array"
+ case reflect.Slice:
+ return "array"
+ case reflect.Chan:
+ return "chan"
+ case reflect.Func:
+ return "func"
+ case reflect.Interface:
+ return "interface"
+ case reflect.Map:
+ return "map"
+ case reflect.Ptr:
+ return "ptr"
+ }
+ return typ.String()
+}
+
+func (s *State) getFormat(name string) expr {
+ if fexpr, found := s.fmt[name]; found {
+ return fexpr
+ }
+
+ if s.default_ != nil {
+ return s.default_
+ }
+
+ s.error(fmt.Sprintf("no format rule for type: '%s'", name))
+ return nil
+}
+
+
+// eval applies a format expression fexpr to a value. If the expression
+// evaluates internally to a non-nil []byte, that slice is appended to
+// the state's output buffer and eval returns true. Otherwise, eval
+// returns false and the state remains unchanged.
+//
+func (s *State) eval(fexpr expr, value reflect.Value, index int) bool {
+ // an empty format expression always evaluates
+ // to a non-nil (but empty) []byte
+ if fexpr == nil {
+ return true
+ }
+
+ switch t := fexpr.(type) {
+ case alternatives:
+ // append the result of the first alternative that evaluates to
+ // a non-nil []byte to the state's output
+ mark := s.save()
+ for _, x := range t {
+ if s.eval(x, value, index) {
+ return true
+ }
+ s.restore(mark)
+ }
+ return false
+
+ case sequence:
+ // append the result of all operands to the state's output
+ // unless a nil result is encountered
+ mark := s.save()
+ for _, x := range t {
+ if !s.eval(x, value, index) {
+ s.restore(mark)
+ return false
+ }
+ }
+ return true
+
+ case literal:
+ // write separator, if any
+ if s.hasOutput {
+ // not the first literal
+ if s.separator != nil {
+ sep := s.separator // save current separator
+ s.separator = nil // and disable it (avoid recursion)
+ mark := s.save()
+ if !s.eval(sep, value, index) {
+ s.restore(mark)
+ }
+ s.separator = sep // enable it again
+ }
+ }
+ s.hasOutput = true
+ // write literal segments
+ for _, lit := range t {
+ if len(lit) > 1 && lit[0] == '%' {
+ // segment contains a %-format at the beginning
+ if lit[1] == '%' {
+ // "%%" is printed as a single "%"
+ s.Write(lit[1:])
+ } else {
+ // use s instead of s.output to get indentation right
+ fmt.Fprintf(s, string(lit), value.Interface())
+ }
+ } else {
+ // segment contains no %-formats
+ s.Write(lit)
+ }
+ }
+ return true // a literal never evaluates to nil
+
+ case *field:
+ // determine field value
+ switch t.fieldName {
+ case "@":
+ // field value is current value
+
+ case "*":
+ // indirection: operation is type-specific
+ switch v := value; v.Kind() {
+ case reflect.Array:
+ if v.Len() <= index {
+ return false
+ }
+ value = v.Index(index)
+
+ case reflect.Slice:
+ if v.IsNil() || v.Len() <= index {
+ return false
+ }
+ value = v.Index(index)
+
+ case reflect.Map:
+ s.error("reflection support for maps incomplete")
+
+ case reflect.Ptr:
+ if v.IsNil() {
+ return false
+ }
+ value = v.Elem()
+
+ case reflect.Interface:
+ if v.IsNil() {
+ return false
+ }
+ value = v.Elem()
+
+ case reflect.Chan:
+ s.error("reflection support for chans incomplete")
+
+ case reflect.Func:
+ s.error("reflection support for funcs incomplete")
+
+ default:
+ s.error(fmt.Sprintf("error: * does not apply to `%s`", value.Type()))
+ }
+
+ default:
+ // value is value of named field
+ var field reflect.Value
+ if sval := value; sval.Kind() == reflect.Struct {
+ field = sval.FieldByName(t.fieldName)
+ if !field.IsValid() {
+ // TODO consider just returning false in this case
+ s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type()))
+ }
+ }
+ value = field
+ }
+
+ // determine rule
+ ruleName := t.ruleName
+ if ruleName == "" {
+ // no alternate rule name, value type determines rule
+ ruleName = typename(value.Type())
+ }
+ fexpr = s.getFormat(ruleName)
+
+ mark := s.save()
+ if !s.eval(fexpr, value, index) {
+ s.restore(mark)
+ return false
+ }
+ return true
+
+ case *group:
+ // remember current indentation
+ indentLen := s.indent.Len()
+
+ // update current indentation
+ mark := s.save()
+ s.eval(t.indent, value, index)
+ // if the indentation evaluates to nil, the state's output buffer
+ // didn't change - either way it's ok to append the difference to
+ // the current identation
+ s.indent.Write(s.output.Bytes()[mark.outputLen:s.output.Len()])
+ s.restore(mark)
+
+ // format group body
+ mark = s.save()
+ b := true
+ if !s.eval(t.body, value, index) {
+ s.restore(mark)
+ b = false
+ }
+
+ // reset indentation
+ s.indent.Truncate(indentLen)
+ return b
+
+ case *option:
+ // evaluate the body and append the result to the state's output
+ // buffer unless the result is nil
+ mark := s.save()
+ if !s.eval(t.body, value, 0) { // TODO is 0 index correct?
+ s.restore(mark)
+ }
+ return true // an option never evaluates to nil
+
+ case *repetition:
+ // evaluate the body and append the result to the state's output
+ // buffer until a result is nil
+ for i := 0; ; i++ {
+ mark := s.save()
+ // write separator, if any
+ if i > 0 && t.separator != nil {
+ // nil result from separator is ignored
+ mark := s.save()
+ if !s.eval(t.separator, value, i) {
+ s.restore(mark)
+ }
+ }
+ if !s.eval(t.body, value, i) {
+ s.restore(mark)
+ break
+ }
+ }
+ return true // a repetition never evaluates to nil
+
+ case *custom:
+ // invoke the custom formatter to obtain the result
+ mark := s.save()
+ if !t.fun(s, value.Interface(), t.ruleName) {
+ s.restore(mark)
+ return false
+ }
+ return true
+ }
+
+ panic("unreachable")
+ return false
+}
+
+
+// Eval formats each argument according to the format
+// f and returns the resulting []byte and os.Error. If
+// an error occurred, the []byte contains the partially
+// formatted result. An environment env may be passed
+// in which is available in custom formatters through
+// the state parameter.
+//
+func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) {
+ if f == nil {
+ return nil, os.NewError("format is nil")
+ }
+
+ errors := make(chan os.Error)
+ s := newState(f, env, errors)
+
+ go func() {
+ for _, v := range args {
+ fld := reflect.NewValue(v)
+ if !fld.IsValid() {
+ errors <- os.NewError("nil argument")
+ return
+ }
+ mark := s.save()
+ if !s.eval(s.getFormat(typename(fld.Type())), fld, 0) { // TODO is 0 index correct?
+ s.restore(mark)
+ }
+ }
+ errors <- nil // no errors
+ }()
+
+ err := <-errors
+ return s.output.Bytes(), err
+}
+
+
+// ----------------------------------------------------------------------------
+// Convenience functions
+
+// Fprint formats each argument according to the format f
+// and writes to w. The result is the total number of bytes
+// written and an os.Error, if any.
+//
+func (f Format) Fprint(w io.Writer, env Environment, args ...interface{}) (int, os.Error) {
+ data, err := f.Eval(env, args...)
+ if err != nil {
+ // TODO should we print partial result in case of error?
+ return 0, err
+ }
+ return w.Write(data)
+}
+
+
+// Print formats each argument according to the format f
+// and writes to standard output. The result is the total
+// number of bytes written and an os.Error, if any.
+//
+func (f Format) Print(args ...interface{}) (int, os.Error) {
+ return f.Fprint(os.Stdout, nil, args...)
+}
+
+
+// Sprint formats each argument according to the format f
+// and returns the resulting string. If an error occurs
+// during formatting, the result string contains the
+// partially formatted result followed by an error message.
+//
+func (f Format) Sprint(args ...interface{}) string {
+ var buf bytes.Buffer
+ _, err := f.Fprint(&buf, nil, args...)
+ if err != nil {
+ var i interface{} = args
+ fmt.Fprintf(&buf, "--- Sprint(%s) failed: %v", fmt.Sprint(i), err)
+ }
+ return buf.String()
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Represents JSON data structure using native Go types: booleans, floats,
+// strings, arrays, and maps.
+
+package json
+
+import (
+ "container/vector"
+ "encoding/base64"
+ "os"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf16"
+ "utf8"
+)
+
+// Unmarshal parses the JSON-encoded data and stores the result
+// in the value pointed to by v.
+//
+// Unmarshal traverses the value v recursively.
+// If an encountered value implements the Unmarshaler interface,
+// Unmarshal calls its UnmarshalJSON method with a well-formed
+// JSON encoding.
+//
+// Otherwise, Unmarshal uses the inverse of the encodings that
+// Marshal uses, allocating maps, slices, and pointers as necessary,
+// with the following additional rules:
+//
+// To unmarshal a JSON value into a nil interface value, the
+// type stored in the interface value is one of:
+//
+// bool, for JSON booleans
+// float64, for JSON numbers
+// string, for JSON strings
+// []interface{}, for JSON arrays
+// map[string]interface{}, for JSON objects
+// nil for JSON null
+//
+// If a JSON value is not appropriate for a given target type,
+// or if a JSON number overflows the target type, Unmarshal
+// skips that field and completes the unmarshalling as best it can.
+// If no more serious errors are encountered, Unmarshal returns
+// an UnmarshalTypeError describing the earliest such error.
+//
+func Unmarshal(data []byte, v interface{}) os.Error {
+ d := new(decodeState).init(data)
+
+ // Quick check for well-formedness.
+ // Avoids filling out half a data structure
+ // before discovering a JSON syntax error.
+ err := checkValid(data, &d.scan)
+ if err != nil {
+ return err
+ }
+
+ return d.unmarshal(v)
+}
+
+// Unmarshaler is the interface implemented by objects
+// that can unmarshal a JSON description of themselves.
+// The input can be assumed to be a valid JSON object
+// encoding. UnmarshalJSON must copy the JSON data
+// if it wishes to retain the data after returning.
+type Unmarshaler interface {
+ UnmarshalJSON([]byte) os.Error
+}
+
+
+// An UnmarshalTypeError describes a JSON value that was
+// not appropriate for a value of a specific Go type.
+type UnmarshalTypeError struct {
+ Value string // description of JSON value - "bool", "array", "number -5"
+ Type reflect.Type // type of Go value it could not be assigned to
+}
+
+func (e *UnmarshalTypeError) String() string {
+ return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String()
+}
+
+// An UnmarshalFieldError describes a JSON object key that
+// led to an unexported (and therefore unwritable) struct field.
+type UnmarshalFieldError struct {
+ Key string
+ Type *reflect.StructType
+ Field reflect.StructField
+}
+
+func (e *UnmarshalFieldError) String() string {
+ return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String()
+}
+
+// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
+// (The argument to Unmarshal must be a non-nil pointer.)
+type InvalidUnmarshalError struct {
+ Type reflect.Type
+}
+
+func (e *InvalidUnmarshalError) String() string {
+ if e.Type == nil {
+ return "json: Unmarshal(nil)"
+ }
+
+ if _, ok := e.Type.(*reflect.PtrType); !ok {
+ return "json: Unmarshal(non-pointer " + e.Type.String() + ")"
+ }
+ return "json: Unmarshal(nil " + e.Type.String() + ")"
+}
+
+func (d *decodeState) unmarshal(v interface{}) (err os.Error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ err = r.(os.Error)
+ }
+ }()
+
+ rv := reflect.NewValue(v)
+ pv, ok := rv.(*reflect.PtrValue)
+ if !ok || pv.IsNil() {
+ return &InvalidUnmarshalError{reflect.Typeof(v)}
+ }
+
+ d.scan.reset()
+ // We decode rv not pv.Elem because the Unmarshaler interface
+ // test must be applied at the top level of the value.
+ d.value(rv)
+ return d.savedError
+}
+
+// decodeState represents the state while decoding a JSON value.
+type decodeState struct {
+ data []byte
+ off int // read offset in data
+ scan scanner
+ nextscan scanner // for calls to nextValue
+ savedError os.Error
+}
+
+// errPhase is used for errors that should not happen unless
+// there is a bug in the JSON decoder or something is editing
+// the data slice while the decoder executes.
+var errPhase = os.NewError("JSON decoder out of sync - data changing underfoot?")
+
+func (d *decodeState) init(data []byte) *decodeState {
+ d.data = data
+ d.off = 0
+ d.savedError = nil
+ return d
+}
+
+// error aborts the decoding by panicking with err.
+func (d *decodeState) error(err os.Error) {
+ panic(err)
+}
+
+// saveError saves the first err it is called with,
+// for reporting at the end of the unmarshal.
+func (d *decodeState) saveError(err os.Error) {
+ if d.savedError == nil {
+ d.savedError = err
+ }
+}
+
+// next cuts off and returns the next full JSON value in d.data[d.off:].
+// The next value is known to be an object or array, not a literal.
+func (d *decodeState) next() []byte {
+ c := d.data[d.off]
+ item, rest, err := nextValue(d.data[d.off:], &d.nextscan)
+ if err != nil {
+ d.error(err)
+ }
+ d.off = len(d.data) - len(rest)
+
+ // Our scanner has seen the opening brace/bracket
+ // and thinks we're still in the middle of the object.
+ // invent a closing brace/bracket to get it out.
+ if c == '{' {
+ d.scan.step(&d.scan, '}')
+ } else {
+ d.scan.step(&d.scan, ']')
+ }
+
+ return item
+}
+
+// scanWhile processes bytes in d.data[d.off:] until it
+// receives a scan code not equal to op.
+// It updates d.off and returns the new scan code.
+func (d *decodeState) scanWhile(op int) int {
+ var newOp int
+ for {
+ if d.off >= len(d.data) {
+ newOp = d.scan.eof()
+ d.off = len(d.data) + 1 // mark processed EOF with len+1
+ } else {
+ c := int(d.data[d.off])
+ d.off++
+ newOp = d.scan.step(&d.scan, c)
+ }
+ if newOp != op {
+ break
+ }
+ }
+ return newOp
+}
+
+// value decodes a JSON value from d.data[d.off:] into the value.
+// it updates d.off to point past the decoded value.
+func (d *decodeState) value(v reflect.Value) {
+ if v == nil {
+ _, rest, err := nextValue(d.data[d.off:], &d.nextscan)
+ if err != nil {
+ d.error(err)
+ }
+ d.off = len(d.data) - len(rest)
+
+ // d.scan thinks we're still at the beginning of the item.
+ // Feed in an empty string - the shortest, simplest value -
+ // so that it knows we got to the end of the value.
+ if d.scan.step == stateRedo {
+ panic("redo")
+ }
+ d.scan.step(&d.scan, '"')
+ d.scan.step(&d.scan, '"')
+ return
+ }
+
+ switch op := d.scanWhile(scanSkipSpace); op {
+ default:
+ d.error(errPhase)
+
+ case scanBeginArray:
+ d.array(v)
+
+ case scanBeginObject:
+ d.object(v)
+
+ case scanBeginLiteral:
+ d.literal(v)
+ }
+}
+
+// indirect walks down v allocating pointers as needed,
+// until it gets to a non-pointer.
+// if it encounters an Unmarshaler, indirect stops and returns that.
+// if wantptr is true, indirect stops at the last pointer.
+func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, reflect.Value) {
+ for {
+ var isUnmarshaler bool
+ if v.Type().NumMethod() > 0 {
+ // Remember that this is an unmarshaler,
+ // but wait to return it until after allocating
+ // the pointer (if necessary).
+ _, isUnmarshaler = v.Interface().(Unmarshaler)
+ }
+
+ if iv, ok := v.(*reflect.InterfaceValue); ok && !iv.IsNil() {
+ v = iv.Elem()
+ continue
+ }
+ pv, ok := v.(*reflect.PtrValue)
+ if !ok {
+ break
+ }
+ _, isptrptr := pv.Elem().(*reflect.PtrValue)
+ if !isptrptr && wantptr && !isUnmarshaler {
+ return nil, pv
+ }
+ if pv.IsNil() {
+ pv.PointTo(reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem()))
+ }
+ if isUnmarshaler {
+ // Using v.Interface().(Unmarshaler)
+ // here means that we have to use a pointer
+ // as the struct field. We cannot use a value inside
+ // a pointer to a struct, because in that case
+ // v.Interface() is the value (x.f) not the pointer (&x.f).
+ // This is an unfortunate consequence of reflect.
+ // An alternative would be to look up the
+ // UnmarshalJSON method and return a FuncValue.
+ return v.Interface().(Unmarshaler), nil
+ }
+ v = pv.Elem()
+ }
+ return nil, v
+}
+
+// array consumes an array from d.data[d.off-1:], decoding into the value v.
+// the first byte of the array ('[') has been read already.
+func (d *decodeState) array(v reflect.Value) {
+ // Check for unmarshaler.
+ unmarshaler, pv := d.indirect(v, false)
+ if unmarshaler != nil {
+ d.off--
+ err := unmarshaler.UnmarshalJSON(d.next())
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ // Decoding into nil interface? Switch to non-reflect code.
+ iv, ok := v.(*reflect.InterfaceValue)
+ if ok {
+ iv.Set(reflect.NewValue(d.arrayInterface()))
+ return
+ }
+
+ // Check type of target.
+ av, ok := v.(reflect.ArrayOrSliceValue)
+ if !ok {
+ d.saveError(&UnmarshalTypeError{"array", v.Type()})
+ d.off--
+ d.next()
+ return
+ }
+
+ sv, _ := v.(*reflect.SliceValue)
+
+ i := 0
+ for {
+ // Look ahead for ] - can only happen on first iteration.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+
+ // Back up so d.value can have the byte we just read.
+ d.off--
+ d.scan.undo(op)
+
+ // Get element of array, growing if necessary.
+ if i >= av.Cap() && sv != nil {
+ newcap := sv.Cap() + sv.Cap()/2
+ if newcap < 4 {
+ newcap = 4
+ }
+ newv := reflect.MakeSlice(sv.Type().(*reflect.SliceType), sv.Len(), newcap)
+ reflect.Copy(newv, sv)
+ sv.Set(newv)
+ }
+ if i >= av.Len() && sv != nil {
+ // Must be slice; gave up on array during i >= av.Cap().
+ sv.SetLen(i + 1)
+ }
+
+ // Decode into element.
+ if i < av.Len() {
+ d.value(av.Elem(i))
+ } else {
+ // Ran out of fixed array: skip.
+ d.value(nil)
+ }
+ i++
+
+ // Next token must be , or ].
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+ if op != scanArrayValue {
+ d.error(errPhase)
+ }
+ }
+ if i < av.Len() {
+ if sv == nil {
+ // Array. Zero the rest.
+ z := reflect.MakeZero(av.Type().(*reflect.ArrayType).Elem())
+ for ; i < av.Len(); i++ {
+ av.Elem(i).SetValue(z)
+ }
+ } else {
+ sv.SetLen(i)
+ }
+ }
+}
+
+// matchName returns true if key should be written to a field named name.
+func matchName(key, name string) bool {
+ return strings.ToLower(key) == strings.ToLower(name)
+}
+
+// object consumes an object from d.data[d.off-1:], decoding into the value v.
+// the first byte of the object ('{') has been read already.
+func (d *decodeState) object(v reflect.Value) {
+ // Check for unmarshaler.
+ unmarshaler, pv := d.indirect(v, false)
+ if unmarshaler != nil {
+ d.off--
+ err := unmarshaler.UnmarshalJSON(d.next())
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ // Decoding into nil interface? Switch to non-reflect code.
+ iv, ok := v.(*reflect.InterfaceValue)
+ if ok {
+ iv.Set(reflect.NewValue(d.objectInterface()))
+ return
+ }
+
+ // Check type of target: struct or map[string]T
+ var (
+ mv *reflect.MapValue
+ sv *reflect.StructValue
+ )
+ switch v := v.(type) {
+ case *reflect.MapValue:
+ // map must have string type
+ t := v.Type().(*reflect.MapType)
+ if t.Key() != reflect.Typeof("") {
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ break
+ }
+ mv = v
+ if mv.IsNil() {
+ mv.SetValue(reflect.MakeMap(t))
+ }
+ case *reflect.StructValue:
+ sv = v
+ default:
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ }
+
+ if mv == nil && sv == nil {
+ d.off--
+ d.next() // skip over { } in input
+ return
+ }
+
+ for {
+ // Read opening " of string key or closing }.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ // closing } - can only happen on first iteration.
+ break
+ }
+ if op != scanBeginLiteral {
+ d.error(errPhase)
+ }
+
+ // Read string key.
+ start := d.off - 1
+ op = d.scanWhile(scanContinue)
+ item := d.data[start : d.off-1]
+ key, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+
+ // Figure out field corresponding to key.
+ var subv reflect.Value
+ if mv != nil {
+ subv = reflect.MakeZero(mv.Type().(*reflect.MapType).Elem())
+ } else {
+ var f reflect.StructField
+ var ok bool
+ st := sv.Type().(*reflect.StructType)
+ // First try for field with that tag.
+ if isValidTag(key) {
+ for i := 0; i < sv.NumField(); i++ {
+ f = st.Field(i)
+ if f.Tag == key {
+ ok = true
+ break
+ }
+ }
+ }
+ if !ok {
+ // Second, exact match.
+ f, ok = st.FieldByName(key)
+ }
+ if !ok {
+ // Third, case-insensitive match.
+ f, ok = st.FieldByNameFunc(func(s string) bool { return matchName(key, s) })
+ }
+
+ // Extract value; name must be exported.
+ if ok {
+ if f.PkgPath != "" {
+ d.saveError(&UnmarshalFieldError{key, st, f})
+ } else {
+ subv = sv.FieldByIndex(f.Index)
+ }
+ }
+ }
+
+ // Read : before value.
+ if op == scanSkipSpace {
+ op = d.scanWhile(scanSkipSpace)
+ }
+ if op != scanObjectKey {
+ d.error(errPhase)
+ }
+
+ // Read value.
+ d.value(subv)
+
+ // Write value back to map;
+ // if using struct, subv points into struct already.
+ if mv != nil {
+ mv.SetElem(reflect.NewValue(key), subv)
+ }
+
+ // Next token must be , or }.
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ break
+ }
+ if op != scanObjectValue {
+ d.error(errPhase)
+ }
+ }
+}
+
+// literal consumes a literal from d.data[d.off-1:], decoding into the value v.
+// The first byte of the literal has been read already
+// (that's how the caller knows it's a literal).
+func (d *decodeState) literal(v reflect.Value) {
+ // All bytes inside literal return scanContinue op code.
+ start := d.off - 1
+ op := d.scanWhile(scanContinue)
+
+ // Scan read one byte too far; back up.
+ d.off--
+ d.scan.undo(op)
+ item := d.data[start:d.off]
+
+ // Check for unmarshaler.
+ wantptr := item[0] == 'n' // null
+ unmarshaler, pv := d.indirect(v, wantptr)
+ if unmarshaler != nil {
+ err := unmarshaler.UnmarshalJSON(item)
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ switch c := item[0]; c {
+ case 'n': // null
+ switch v.(type) {
+ default:
+ d.saveError(&UnmarshalTypeError{"null", v.Type()})
+ case *reflect.InterfaceValue, *reflect.PtrValue, *reflect.MapValue:
+ v.SetValue(nil)
+ }
+
+ case 't', 'f': // true, false
+ value := c == 't'
+ switch v := v.(type) {
+ default:
+ d.saveError(&UnmarshalTypeError{"bool", v.Type()})
+ case *reflect.BoolValue:
+ v.Set(value)
+ case *reflect.InterfaceValue:
+ v.Set(reflect.NewValue(value))
+ }
+
+ case '"': // string
+ s, ok := unquoteBytes(item)
+ if !ok {
+ d.error(errPhase)
+ }
+ switch v := v.(type) {
+ default:
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ case *reflect.SliceValue:
+ if v.Type() != byteSliceType {
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ break
+ }
+ b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
+ n, err := base64.StdEncoding.Decode(b, s)
+ if err != nil {
+ d.saveError(err)
+ break
+ }
+ v.Set(reflect.NewValue(b[0:n]).(*reflect.SliceValue))
+ case *reflect.StringValue:
+ v.Set(string(s))
+ case *reflect.InterfaceValue:
+ v.Set(reflect.NewValue(string(s)))
+ }
+
+ default: // number
+ if c != '-' && (c < '0' || c > '9') {
+ d.error(errPhase)
+ }
+ s := string(item)
+ switch v := v.(type) {
+ default:
+ d.error(&UnmarshalTypeError{"number", v.Type()})
+ case *reflect.InterfaceValue:
+ n, err := strconv.Atof64(s)
+ if err != nil {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(reflect.NewValue(n))
+
+ case *reflect.IntValue:
+ n, err := strconv.Atoi64(s)
+ if err != nil || v.Overflow(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(n)
+
+ case *reflect.UintValue:
+ n, err := strconv.Atoui64(s)
+ if err != nil || v.Overflow(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(n)
+
+ case *reflect.FloatValue:
+ n, err := strconv.AtofN(s, v.Type().Bits())
+ if err != nil || v.Overflow(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(n)
+ }
+ }
+}
+
+// The xxxInterface routines build up a value to be stored
+// in an empty interface. They are not strictly necessary,
+// but they avoid the weight of reflection in this common case.
+
+// valueInterface is like value but returns interface{}
+func (d *decodeState) valueInterface() interface{} {
+ switch d.scanWhile(scanSkipSpace) {
+ default:
+ d.error(errPhase)
+ case scanBeginArray:
+ return d.arrayInterface()
+ case scanBeginObject:
+ return d.objectInterface()
+ case scanBeginLiteral:
+ return d.literalInterface()
+ }
+ panic("unreachable")
+}
+
+// arrayInterface is like array but returns []interface{}.
+func (d *decodeState) arrayInterface() []interface{} {
+ var v vector.Vector
+ for {
+ // Look ahead for ] - can only happen on first iteration.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+
+ // Back up so d.value can have the byte we just read.
+ d.off--
+ d.scan.undo(op)
+
+ v.Push(d.valueInterface())
+
+ // Next token must be , or ].
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+ if op != scanArrayValue {
+ d.error(errPhase)
+ }
+ }
+ return v
+}
+
+// objectInterface is like object but returns map[string]interface{}.
+func (d *decodeState) objectInterface() map[string]interface{} {
+ m := make(map[string]interface{})
+ for {
+ // Read opening " of string key or closing }.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ // closing } - can only happen on first iteration.
+ break
+ }
+ if op != scanBeginLiteral {
+ d.error(errPhase)
+ }
+
+ // Read string key.
+ start := d.off - 1
+ op = d.scanWhile(scanContinue)
+ item := d.data[start : d.off-1]
+ key, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+
+ // Read : before value.
+ if op == scanSkipSpace {
+ op = d.scanWhile(scanSkipSpace)
+ }
+ if op != scanObjectKey {
+ d.error(errPhase)
+ }
+
+ // Read value.
+ m[key] = d.valueInterface()
+
+ // Next token must be , or }.
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ break
+ }
+ if op != scanObjectValue {
+ d.error(errPhase)
+ }
+ }
+ return m
+}
+
+
+// literalInterface is like literal but returns an interface value.
+func (d *decodeState) literalInterface() interface{} {
+ // All bytes inside literal return scanContinue op code.
+ start := d.off - 1
+ op := d.scanWhile(scanContinue)
+
+ // Scan read one byte too far; back up.
+ d.off--
+ d.scan.undo(op)
+ item := d.data[start:d.off]
+
+ switch c := item[0]; c {
+ case 'n': // null
+ return nil
+
+ case 't', 'f': // true, false
+ return c == 't'
+
+ case '"': // string
+ s, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+ return s
+
+ default: // number
+ if c != '-' && (c < '0' || c > '9') {
+ d.error(errPhase)
+ }
+ n, err := strconv.Atof64(string(item))
+ if err != nil {
+ d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)})
+ }
+ return n
+ }
+ panic("unreachable")
+}
+
+// getu4 decodes \uXXXX from the beginning of s, returning the hex value,
+// or it returns -1.
+func getu4(s []byte) int {
+ if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
+ return -1
+ }
+ rune, err := strconv.Btoui64(string(s[2:6]), 16)
+ if err != nil {
+ return -1
+ }
+ return int(rune)
+}
+
+// unquote converts a quoted JSON string literal s into an actual string t.
+// The rules are different than for Go, so cannot use strconv.Unquote.
+func unquote(s []byte) (t string, ok bool) {
+ s, ok = unquoteBytes(s)
+ t = string(s)
+ return
+}
+
+func unquoteBytes(s []byte) (t []byte, ok bool) {
+ if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
+ return
+ }
+ s = s[1 : len(s)-1]
+
+ // Check for unusual characters. If there are none,
+ // then no unquoting is needed, so return a slice of the
+ // original bytes.
+ r := 0
+ for r < len(s) {
+ c := s[r]
+ if c == '\\' || c == '"' || c < ' ' {
+ break
+ }
+ if c < utf8.RuneSelf {
+ r++
+ continue
+ }
+ rune, size := utf8.DecodeRune(s[r:])
+ if rune == utf8.RuneError && size == 1 {
+ break
+ }
+ r += size
+ }
+ if r == len(s) {
+ return s, true
+ }
+
+ b := make([]byte, len(s)+2*utf8.UTFMax)
+ w := copy(b, s[0:r])
+ for r < len(s) {
+ // Out of room? Can only happen if s is full of
+ // malformed UTF-8 and we're replacing each
+ // byte with RuneError.
+ if w >= len(b)-2*utf8.UTFMax {
+ nb := make([]byte, (len(b)+utf8.UTFMax)*2)
+ copy(nb, b[0:w])
+ b = nb
+ }
+ switch c := s[r]; {
+ case c == '\\':
+ r++
+ if r >= len(s) {
+ return
+ }
+ switch s[r] {
+ default:
+ return
+ case '"', '\\', '/', '\'':
+ b[w] = s[r]
+ r++
+ w++
+ case 'b':
+ b[w] = '\b'
+ r++
+ w++
+ case 'f':
+ b[w] = '\f'
+ r++
+ w++
+ case 'n':
+ b[w] = '\n'
+ r++
+ w++
+ case 'r':
+ b[w] = '\r'
+ r++
+ w++
+ case 't':
+ b[w] = '\t'
+ r++
+ w++
+ case 'u':
+ r--
+ rune := getu4(s[r:])
+ if rune < 0 {
+ return
+ }
+ r += 6
+ if utf16.IsSurrogate(rune) {
+ rune1 := getu4(s[r:])
+ if dec := utf16.DecodeRune(rune, rune1); dec != unicode.ReplacementChar {
+ // A valid pair; consume.
+ r += 6
+ w += utf8.EncodeRune(b[w:], dec)
+ break
+ }
+ // Invalid surrogate; fall back to replacement rune.
+ rune = unicode.ReplacementChar
+ }
+ w += utf8.EncodeRune(b[w:], rune)
+ }
+
+ // Quote, control characters are invalid.
+ case c == '"', c < ' ':
+ return
+
+ // ASCII
+ case c < utf8.RuneSelf:
+ b[w] = c
+ r++
+ w++
+
+ // Coerce to well-formed UTF-8.
+ default:
+ rune, size := utf8.DecodeRune(s[r:])
+ r += size
+ w += utf8.EncodeRune(b[w:], rune)
+ }
+ }
+ return b[0:w], true
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Represents JSON data structure using native Go types: booleans, floats,
+// strings, arrays, and maps.
+
+package json
+
+import (
+ "container/vector"
+ "encoding/base64"
+ "os"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf16"
+ "utf8"
+)
+
+// Unmarshal parses the JSON-encoded data and stores the result
+// in the value pointed to by v.
+//
+// Unmarshal traverses the value v recursively.
+// If an encountered value implements the Unmarshaler interface,
+// Unmarshal calls its UnmarshalJSON method with a well-formed
+// JSON encoding.
+//
+// Otherwise, Unmarshal uses the inverse of the encodings that
+// Marshal uses, allocating maps, slices, and pointers as necessary,
+// with the following additional rules:
+//
+// To unmarshal a JSON value into a nil interface value, the
+// type stored in the interface value is one of:
+//
+// bool, for JSON booleans
+// float64, for JSON numbers
+// string, for JSON strings
+// []interface{}, for JSON arrays
+// map[string]interface{}, for JSON objects
+// nil for JSON null
+//
+// If a JSON value is not appropriate for a given target type,
+// or if a JSON number overflows the target type, Unmarshal
+// skips that field and completes the unmarshalling as best it can.
+// If no more serious errors are encountered, Unmarshal returns
+// an UnmarshalTypeError describing the earliest such error.
+//
+func Unmarshal(data []byte, v interface{}) os.Error {
+ d := new(decodeState).init(data)
+
+ // Quick check for well-formedness.
+ // Avoids filling out half a data structure
+ // before discovering a JSON syntax error.
+ err := checkValid(data, &d.scan)
+ if err != nil {
+ return err
+ }
+
+ return d.unmarshal(v)
+}
+
+// Unmarshaler is the interface implemented by objects
+// that can unmarshal a JSON description of themselves.
+// The input can be assumed to be a valid JSON object
+// encoding. UnmarshalJSON must copy the JSON data
+// if it wishes to retain the data after returning.
+type Unmarshaler interface {
+ UnmarshalJSON([]byte) os.Error
+}
+
+
+// An UnmarshalTypeError describes a JSON value that was
+// not appropriate for a value of a specific Go type.
+type UnmarshalTypeError struct {
+ Value string // description of JSON value - "bool", "array", "number -5"
+ Type reflect.Type // type of Go value it could not be assigned to
+}
+
+func (e *UnmarshalTypeError) String() string {
+ return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String()
+}
+
+// An UnmarshalFieldError describes a JSON object key that
+// led to an unexported (and therefore unwritable) struct field.
+type UnmarshalFieldError struct {
+ Key string
+ Type reflect.Type
+ Field reflect.StructField
+}
+
+func (e *UnmarshalFieldError) String() string {
+ return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String()
+}
+
+// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
+// (The argument to Unmarshal must be a non-nil pointer.)
+type InvalidUnmarshalError struct {
+ Type reflect.Type
+}
+
+func (e *InvalidUnmarshalError) String() string {
+ if e.Type == nil {
+ return "json: Unmarshal(nil)"
+ }
+
+ if e.Type.Kind() != reflect.Ptr {
+ return "json: Unmarshal(non-pointer " + e.Type.String() + ")"
+ }
+ return "json: Unmarshal(nil " + e.Type.String() + ")"
+}
+
+func (d *decodeState) unmarshal(v interface{}) (err os.Error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ err = r.(os.Error)
+ }
+ }()
+
+ rv := reflect.NewValue(v)
+ pv := rv
+ if pv.Kind() != reflect.Ptr ||
+ pv.IsNil() {
+ return &InvalidUnmarshalError{reflect.Typeof(v)}
+ }
+
+ d.scan.reset()
+ // We decode rv not pv.Elem because the Unmarshaler interface
+ // test must be applied at the top level of the value.
+ d.value(rv)
+ return d.savedError
+}
+
+// decodeState represents the state while decoding a JSON value.
+type decodeState struct {
+ data []byte
+ off int // read offset in data
+ scan scanner
+ nextscan scanner // for calls to nextValue
+ savedError os.Error
+}
+
+// errPhase is used for errors that should not happen unless
+// there is a bug in the JSON decoder or something is editing
+// the data slice while the decoder executes.
+var errPhase = os.NewError("JSON decoder out of sync - data changing underfoot?")
+
+func (d *decodeState) init(data []byte) *decodeState {
+ d.data = data
+ d.off = 0
+ d.savedError = nil
+ return d
+}
+
+// error aborts the decoding by panicking with err.
+func (d *decodeState) error(err os.Error) {
+ panic(err)
+}
+
+// saveError saves the first err it is called with,
+// for reporting at the end of the unmarshal.
+func (d *decodeState) saveError(err os.Error) {
+ if d.savedError == nil {
+ d.savedError = err
+ }
+}
+
+// next cuts off and returns the next full JSON value in d.data[d.off:].
+// The next value is known to be an object or array, not a literal.
+func (d *decodeState) next() []byte {
+ c := d.data[d.off]
+ item, rest, err := nextValue(d.data[d.off:], &d.nextscan)
+ if err != nil {
+ d.error(err)
+ }
+ d.off = len(d.data) - len(rest)
+
+ // Our scanner has seen the opening brace/bracket
+ // and thinks we're still in the middle of the object.
+ // invent a closing brace/bracket to get it out.
+ if c == '{' {
+ d.scan.step(&d.scan, '}')
+ } else {
+ d.scan.step(&d.scan, ']')
+ }
+
+ return item
+}
+
+// scanWhile processes bytes in d.data[d.off:] until it
+// receives a scan code not equal to op.
+// It updates d.off and returns the new scan code.
+func (d *decodeState) scanWhile(op int) int {
+ var newOp int
+ for {
+ if d.off >= len(d.data) {
+ newOp = d.scan.eof()
+ d.off = len(d.data) + 1 // mark processed EOF with len+1
+ } else {
+ c := int(d.data[d.off])
+ d.off++
+ newOp = d.scan.step(&d.scan, c)
+ }
+ if newOp != op {
+ break
+ }
+ }
+ return newOp
+}
+
+// value decodes a JSON value from d.data[d.off:] into the value.
+// it updates d.off to point past the decoded value.
+func (d *decodeState) value(v reflect.Value) {
+ if !v.IsValid() {
+ _, rest, err := nextValue(d.data[d.off:], &d.nextscan)
+ if err != nil {
+ d.error(err)
+ }
+ d.off = len(d.data) - len(rest)
+
+ // d.scan thinks we're still at the beginning of the item.
+ // Feed in an empty string - the shortest, simplest value -
+ // so that it knows we got to the end of the value.
+ if d.scan.step == stateRedo {
+ panic("redo")
+ }
+ d.scan.step(&d.scan, '"')
+ d.scan.step(&d.scan, '"')
+ return
+ }
+
+ switch op := d.scanWhile(scanSkipSpace); op {
+ default:
+ d.error(errPhase)
+
+ case scanBeginArray:
+ d.array(v)
+
+ case scanBeginObject:
+ d.object(v)
+
+ case scanBeginLiteral:
+ d.literal(v)
+ }
+}
+
+// indirect walks down v allocating pointers as needed,
+// until it gets to a non-pointer.
+// if it encounters an Unmarshaler, indirect stops and returns that.
+// if wantptr is true, indirect stops at the last pointer.
+func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, reflect.Value) {
+ for {
+ var isUnmarshaler bool
+ if v.Type().NumMethod() > 0 {
+ // Remember that this is an unmarshaler,
+ // but wait to return it until after allocating
+ // the pointer (if necessary).
+ _, isUnmarshaler = v.Interface().(Unmarshaler)
+ }
+
+ if iv := v; iv.Kind() == reflect.Interface && !iv.IsNil() {
+ v = iv.Elem()
+ continue
+ }
+ pv := v
+ if pv.Kind() != reflect.Ptr {
+ break
+ }
+
+ if pv.Elem().Kind() != reflect.Ptr &&
+ wantptr && !isUnmarshaler {
+ return nil, pv
+ }
+ if pv.IsNil() {
+ pv.Set(reflect.Zero(pv.Type().Elem()).Addr())
+ }
+ if isUnmarshaler {
+ // Using v.Interface().(Unmarshaler)
+ // here means that we have to use a pointer
+ // as the struct field. We cannot use a value inside
+ // a pointer to a struct, because in that case
+ // v.Interface() is the value (x.f) not the pointer (&x.f).
+ // This is an unfortunate consequence of reflect.
+ // An alternative would be to look up the
+ // UnmarshalJSON method and return a FuncValue.
+ return v.Interface().(Unmarshaler), reflect.Value{}
+ }
+ v = pv.Elem()
+ }
+ return nil, v
+}
+
+// array consumes an array from d.data[d.off-1:], decoding into the value v.
+// the first byte of the array ('[') has been read already.
+func (d *decodeState) array(v reflect.Value) {
+ // Check for unmarshaler.
+ unmarshaler, pv := d.indirect(v, false)
+ if unmarshaler != nil {
+ d.off--
+ err := unmarshaler.UnmarshalJSON(d.next())
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ // Decoding into nil interface? Switch to non-reflect code.
+ iv := v
+ ok := iv.Kind() == reflect.Interface
+ if ok {
+ iv.Set(reflect.NewValue(d.arrayInterface()))
+ return
+ }
+
+ // Check type of target.
+ av := v
+ if av.Kind() != reflect.Array && av.Kind() != reflect.Slice {
+ d.saveError(&UnmarshalTypeError{"array", v.Type()})
+ d.off--
+ d.next()
+ return
+ }
+
+ sv := v
+
+ i := 0
+ for {
+ // Look ahead for ] - can only happen on first iteration.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+
+ // Back up so d.value can have the byte we just read.
+ d.off--
+ d.scan.undo(op)
+
+ // Get element of array, growing if necessary.
+ if i >= av.Cap() && sv.IsValid() {
+ newcap := sv.Cap() + sv.Cap()/2
+ if newcap < 4 {
+ newcap = 4
+ }
+ newv := reflect.MakeSlice(sv.Type(), sv.Len(), newcap)
+ reflect.Copy(newv, sv)
+ sv.Set(newv)
+ }
+ if i >= av.Len() && sv.IsValid() {
+ // Must be slice; gave up on array during i >= av.Cap().
+ sv.SetLen(i + 1)
+ }
+
+ // Decode into element.
+ if i < av.Len() {
+ d.value(av.Index(i))
+ } else {
+ // Ran out of fixed array: skip.
+ d.value(reflect.Value{})
+ }
+ i++
+
+ // Next token must be , or ].
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+ if op != scanArrayValue {
+ d.error(errPhase)
+ }
+ }
+ if i < av.Len() {
+ if !sv.IsValid() {
+ // Array. Zero the rest.
+ z := reflect.Zero(av.Type().Elem())
+ for ; i < av.Len(); i++ {
+ av.Index(i).Set(z)
+ }
+ } else {
+ sv.SetLen(i)
+ }
+ }
+}
+
+// matchName returns true if key should be written to a field named name.
+func matchName(key, name string) bool {
+ return strings.ToLower(key) == strings.ToLower(name)
+}
+
+// object consumes an object from d.data[d.off-1:], decoding into the value v.
+// the first byte of the object ('{') has been read already.
+func (d *decodeState) object(v reflect.Value) {
+ // Check for unmarshaler.
+ unmarshaler, pv := d.indirect(v, false)
+ if unmarshaler != nil {
+ d.off--
+ err := unmarshaler.UnmarshalJSON(d.next())
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ // Decoding into nil interface? Switch to non-reflect code.
+ iv := v
+ if iv.Kind() == reflect.Interface {
+ iv.Set(reflect.NewValue(d.objectInterface()))
+ return
+ }
+
+ // Check type of target: struct or map[string]T
+ var (
+ mv reflect.Value
+ sv reflect.Value
+ )
+ switch v.Kind() {
+ case reflect.Map:
+ // map must have string type
+ t := v.Type()
+ if t.Key() != reflect.Typeof("") {
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ break
+ }
+ mv = v
+ if mv.IsNil() {
+ mv.Set(reflect.MakeMap(t))
+ }
+ case reflect.Struct:
+ sv = v
+ default:
+ d.saveError(&UnmarshalTypeError{"object", v.Type()})
+ }
+
+ if !mv.IsValid() && !sv.IsValid() {
+ d.off--
+ d.next() // skip over { } in input
+ return
+ }
+
+ for {
+ // Read opening " of string key or closing }.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ // closing } - can only happen on first iteration.
+ break
+ }
+ if op != scanBeginLiteral {
+ d.error(errPhase)
+ }
+
+ // Read string key.
+ start := d.off - 1
+ op = d.scanWhile(scanContinue)
+ item := d.data[start : d.off-1]
+ key, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+
+ // Figure out field corresponding to key.
+ var subv reflect.Value
+ if mv.IsValid() {
+ subv = reflect.Zero(mv.Type().Elem())
+ } else {
+ var f reflect.StructField
+ var ok bool
+ st := sv.Type()
+ // First try for field with that tag.
+ if isValidTag(key) {
+ for i := 0; i < sv.NumField(); i++ {
+ f = st.Field(i)
+ if f.Tag == key {
+ ok = true
+ break
+ }
+ }
+ }
+ if !ok {
+ // Second, exact match.
+ f, ok = st.FieldByName(key)
+ }
+ if !ok {
+ // Third, case-insensitive match.
+ f, ok = st.FieldByNameFunc(func(s string) bool { return matchName(key, s) })
+ }
+
+ // Extract value; name must be exported.
+ if ok {
+ if f.PkgPath != "" {
+ d.saveError(&UnmarshalFieldError{key, st, f})
+ } else {
+ subv = sv.FieldByIndex(f.Index)
+ }
+ }
+ }
+
+ // Read : before value.
+ if op == scanSkipSpace {
+ op = d.scanWhile(scanSkipSpace)
+ }
+ if op != scanObjectKey {
+ d.error(errPhase)
+ }
+
+ // Read value.
+ d.value(subv)
+
+ // Write value back to map;
+ // if using struct, subv points into struct already.
+ if mv.IsValid() {
+ mv.SetMapIndex(reflect.NewValue(key), subv)
+ }
+
+ // Next token must be , or }.
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ break
+ }
+ if op != scanObjectValue {
+ d.error(errPhase)
+ }
+ }
+}
+
+// literal consumes a literal from d.data[d.off-1:], decoding into the value v.
+// The first byte of the literal has been read already
+// (that's how the caller knows it's a literal).
+func (d *decodeState) literal(v reflect.Value) {
+ // All bytes inside literal return scanContinue op code.
+ start := d.off - 1
+ op := d.scanWhile(scanContinue)
+
+ // Scan read one byte too far; back up.
+ d.off--
+ d.scan.undo(op)
+ item := d.data[start:d.off]
+
+ // Check for unmarshaler.
+ wantptr := item[0] == 'n' // null
+ unmarshaler, pv := d.indirect(v, wantptr)
+ if unmarshaler != nil {
+ err := unmarshaler.UnmarshalJSON(item)
+ if err != nil {
+ d.error(err)
+ }
+ return
+ }
+ v = pv
+
+ switch c := item[0]; c {
+ case 'n': // null
+ switch v.Kind() {
+ default:
+ d.saveError(&UnmarshalTypeError{"null", v.Type()})
+ case reflect.Interface, reflect.Ptr, reflect.Map:
+ v.Set(reflect.Zero(v.Type()))
+ }
+
+ case 't', 'f': // true, false
+ value := c == 't'
+ switch v.Kind() {
+ default:
+ d.saveError(&UnmarshalTypeError{"bool", v.Type()})
+ case reflect.Bool:
+ v.SetBool(value)
+ case reflect.Interface:
+ v.Set(reflect.NewValue(value))
+ }
+
+ case '"': // string
+ s, ok := unquoteBytes(item)
+ if !ok {
+ d.error(errPhase)
+ }
+ switch v.Kind() {
+ default:
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ case reflect.Slice:
+ if v.Type() != byteSliceType {
+ d.saveError(&UnmarshalTypeError{"string", v.Type()})
+ break
+ }
+ b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
+ n, err := base64.StdEncoding.Decode(b, s)
+ if err != nil {
+ d.saveError(err)
+ break
+ }
+ v.Set(reflect.NewValue(b[0:n]))
+ case reflect.String:
+ v.SetString(string(s))
+ case reflect.Interface:
+ v.Set(reflect.NewValue(string(s)))
+ }
+
+ default: // number
+ if c != '-' && (c < '0' || c > '9') {
+ d.error(errPhase)
+ }
+ s := string(item)
+ switch v.Kind() {
+ default:
+ d.error(&UnmarshalTypeError{"number", v.Type()})
+ case reflect.Interface:
+ n, err := strconv.Atof64(s)
+ if err != nil {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.Set(reflect.NewValue(n))
+
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ n, err := strconv.Atoi64(s)
+ if err != nil || v.OverflowInt(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.SetInt(n)
+
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ n, err := strconv.Atoui64(s)
+ if err != nil || v.OverflowUint(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.SetUint(n)
+
+ case reflect.Float32, reflect.Float64:
+ n, err := strconv.AtofN(s, v.Type().Bits())
+ if err != nil || v.OverflowFloat(n) {
+ d.saveError(&UnmarshalTypeError{"number " + s, v.Type()})
+ break
+ }
+ v.SetFloat(n)
+ }
+ }
+}
+
+// The xxxInterface routines build up a value to be stored
+// in an empty interface. They are not strictly necessary,
+// but they avoid the weight of reflection in this common case.
+
+// valueInterface is like value but returns interface{}
+func (d *decodeState) valueInterface() interface{} {
+ switch d.scanWhile(scanSkipSpace) {
+ default:
+ d.error(errPhase)
+ case scanBeginArray:
+ return d.arrayInterface()
+ case scanBeginObject:
+ return d.objectInterface()
+ case scanBeginLiteral:
+ return d.literalInterface()
+ }
+ panic("unreachable")
+}
+
+// arrayInterface is like array but returns []interface{}.
+func (d *decodeState) arrayInterface() []interface{} {
+ var v vector.Vector
+ for {
+ // Look ahead for ] - can only happen on first iteration.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+
+ // Back up so d.value can have the byte we just read.
+ d.off--
+ d.scan.undo(op)
+
+ v.Push(d.valueInterface())
+
+ // Next token must be , or ].
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndArray {
+ break
+ }
+ if op != scanArrayValue {
+ d.error(errPhase)
+ }
+ }
+ return v
+}
+
+// objectInterface is like object but returns map[string]interface{}.
+func (d *decodeState) objectInterface() map[string]interface{} {
+ m := make(map[string]interface{})
+ for {
+ // Read opening " of string key or closing }.
+ op := d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ // closing } - can only happen on first iteration.
+ break
+ }
+ if op != scanBeginLiteral {
+ d.error(errPhase)
+ }
+
+ // Read string key.
+ start := d.off - 1
+ op = d.scanWhile(scanContinue)
+ item := d.data[start : d.off-1]
+ key, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+
+ // Read : before value.
+ if op == scanSkipSpace {
+ op = d.scanWhile(scanSkipSpace)
+ }
+ if op != scanObjectKey {
+ d.error(errPhase)
+ }
+
+ // Read value.
+ m[key] = d.valueInterface()
+
+ // Next token must be , or }.
+ op = d.scanWhile(scanSkipSpace)
+ if op == scanEndObject {
+ break
+ }
+ if op != scanObjectValue {
+ d.error(errPhase)
+ }
+ }
+ return m
+}
+
+
+// literalInterface is like literal but returns an interface value.
+func (d *decodeState) literalInterface() interface{} {
+ // All bytes inside literal return scanContinue op code.
+ start := d.off - 1
+ op := d.scanWhile(scanContinue)
+
+ // Scan read one byte too far; back up.
+ d.off--
+ d.scan.undo(op)
+ item := d.data[start:d.off]
+
+ switch c := item[0]; c {
+ case 'n': // null
+ return nil
+
+ case 't', 'f': // true, false
+ return c == 't'
+
+ case '"': // string
+ s, ok := unquote(item)
+ if !ok {
+ d.error(errPhase)
+ }
+ return s
+
+ default: // number
+ if c != '-' && (c < '0' || c > '9') {
+ d.error(errPhase)
+ }
+ n, err := strconv.Atof64(string(item))
+ if err != nil {
+ d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)})
+ }
+ return n
+ }
+ panic("unreachable")
+}
+
+// getu4 decodes \uXXXX from the beginning of s, returning the hex value,
+// or it returns -1.
+func getu4(s []byte) int {
+ if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
+ return -1
+ }
+ rune, err := strconv.Btoui64(string(s[2:6]), 16)
+ if err != nil {
+ return -1
+ }
+ return int(rune)
+}
+
+// unquote converts a quoted JSON string literal s into an actual string t.
+// The rules are different than for Go, so cannot use strconv.Unquote.
+func unquote(s []byte) (t string, ok bool) {
+ s, ok = unquoteBytes(s)
+ t = string(s)
+ return
+}
+
+func unquoteBytes(s []byte) (t []byte, ok bool) {
+ if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
+ return
+ }
+ s = s[1 : len(s)-1]
+
+ // Check for unusual characters. If there are none,
+ // then no unquoting is needed, so return a slice of the
+ // original bytes.
+ r := 0
+ for r < len(s) {
+ c := s[r]
+ if c == '\\' || c == '"' || c < ' ' {
+ break
+ }
+ if c < utf8.RuneSelf {
+ r++
+ continue
+ }
+ rune, size := utf8.DecodeRune(s[r:])
+ if rune == utf8.RuneError && size == 1 {
+ break
+ }
+ r += size
+ }
+ if r == len(s) {
+ return s, true
+ }
+
+ b := make([]byte, len(s)+2*utf8.UTFMax)
+ w := copy(b, s[0:r])
+ for r < len(s) {
+ // Out of room? Can only happen if s is full of
+ // malformed UTF-8 and we're replacing each
+ // byte with RuneError.
+ if w >= len(b)-2*utf8.UTFMax {
+ nb := make([]byte, (len(b)+utf8.UTFMax)*2)
+ copy(nb, b[0:w])
+ b = nb
+ }
+ switch c := s[r]; {
+ case c == '\\':
+ r++
+ if r >= len(s) {
+ return
+ }
+ switch s[r] {
+ default:
+ return
+ case '"', '\\', '/', '\'':
+ b[w] = s[r]
+ r++
+ w++
+ case 'b':
+ b[w] = '\b'
+ r++
+ w++
+ case 'f':
+ b[w] = '\f'
+ r++
+ w++
+ case 'n':
+ b[w] = '\n'
+ r++
+ w++
+ case 'r':
+ b[w] = '\r'
+ r++
+ w++
+ case 't':
+ b[w] = '\t'
+ r++
+ w++
+ case 'u':
+ r--
+ rune := getu4(s[r:])
+ if rune < 0 {
+ return
+ }
+ r += 6
+ if utf16.IsSurrogate(rune) {
+ rune1 := getu4(s[r:])
+ if dec := utf16.DecodeRune(rune, rune1); dec != unicode.ReplacementChar {
+ // A valid pair; consume.
+ r += 6
+ w += utf8.EncodeRune(b[w:], dec)
+ break
+ }
+ // Invalid surrogate; fall back to replacement rune.
+ rune = unicode.ReplacementChar
+ }
+ w += utf8.EncodeRune(b[w:], rune)
+ }
+
+ // Quote, control characters are invalid.
+ case c == '"', c < ' ':
+ return
+
+ // ASCII
+ case c < utf8.RuneSelf:
+ b[w] = c
+ r++
+ w++
+
+ // Coerce to well-formed UTF-8.
+ default:
+ rune, size := utf8.DecodeRune(s[r:])
+ r += size
+ w += utf8.EncodeRune(b[w:], rune)
+ }
+ }
+ return b[0:w], true
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "bufio"
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "sync"
+)
+
+// A Decoder manages the receipt of type and data information read from the
+// remote side of a connection.
+type Decoder struct {
+ mutex sync.Mutex // each item must be received atomically
+ r io.Reader // source of the data
+ buf bytes.Buffer // buffer for more efficient i/o from r
+ wireType map[typeId]*wireType // map from remote ID to local description
+ decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
+ ignorerCache map[typeId]**decEngine // ditto for ignored objects
+ freeList *decoderState // list of free decoderStates; avoids reallocation
+ countBuf []byte // used for decoding integers while parsing messages
+ tmp []byte // temporary storage for i/o; saves reallocating
+ err os.Error
+}
+
+// NewDecoder returns a new decoder that reads from the io.Reader.
+func NewDecoder(r io.Reader) *Decoder {
+ dec := new(Decoder)
+ dec.r = bufio.NewReader(r)
+ dec.wireType = make(map[typeId]*wireType)
+ dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
+ dec.ignorerCache = make(map[typeId]**decEngine)
+ dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
+
+ return dec
+}
+
+// recvType loads the definition of a type.
+func (dec *Decoder) recvType(id typeId) {
+ // Have we already seen this type? That's an error
+ if id < firstUserId || dec.wireType[id] != nil {
+ dec.err = os.ErrorString("gob: duplicate type received")
+ return
+ }
+
+ // Type:
+ wire := new(wireType)
+ dec.decodeValue(tWireType, reflect.NewValue(wire))
+ if dec.err != nil {
+ return
+ }
+ // Remember we've seen this type.
+ dec.wireType[id] = wire
+}
+
+// recvMessage reads the next count-delimited item from the input. It is the converse
+// of Encoder.writeMessage. It returns false on EOF or other error reading the message.
+func (dec *Decoder) recvMessage() bool {
+ // Read a count.
+ nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ return false
+ }
+ dec.readMessage(int(nbytes))
+ return dec.err == nil
+}
+
+// readMessage reads the next nbytes bytes from the input.
+func (dec *Decoder) readMessage(nbytes int) {
+ // Allocate the buffer.
+ if cap(dec.tmp) < nbytes {
+ dec.tmp = make([]byte, nbytes+100) // room to grow
+ }
+ dec.tmp = dec.tmp[:nbytes]
+
+ // Read the data
+ _, dec.err = io.ReadFull(dec.r, dec.tmp)
+ if dec.err != nil {
+ if dec.err == os.EOF {
+ dec.err = io.ErrUnexpectedEOF
+ }
+ return
+ }
+ dec.buf.Write(dec.tmp)
+}
+
+// toInt turns an encoded uint64 into an int, according to the marshaling rules.
+func toInt(x uint64) int64 {
+ i := int64(x >> 1)
+ if x&1 != 0 {
+ i = ^i
+ }
+ return i
+}
+
+func (dec *Decoder) nextInt() int64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return toInt(n)
+}
+
+func (dec *Decoder) nextUint() uint64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return n
+}
+
+// decodeTypeSequence parses:
+// TypeSequence
+// (TypeDefinition DelimitedTypeDefinition*)?
+// and returns the type id of the next value. It returns -1 at
+// EOF. Upon return, the remainder of dec.buf is the value to be
+// decoded. If this is an interface value, it can be ignored by
+// simply resetting that buffer.
+func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
+ for dec.err == nil {
+ if dec.buf.Len() == 0 {
+ if !dec.recvMessage() {
+ break
+ }
+ }
+ // Receive a type id.
+ id := typeId(dec.nextInt())
+ if id >= 0 {
+ // Value follows.
+ return id
+ }
+ // Type definition for (-id) follows.
+ dec.recvType(-id)
+ // When decoding an interface, after a type there may be a
+ // DelimitedValue still in the buffer. Skip its count.
+ // (Alternatively, the buffer is empty and the byte count
+ // will be absorbed by recvMessage.)
+ if dec.buf.Len() > 0 {
+ if !isInterface {
+ dec.err = os.ErrorString("extra data in buffer")
+ break
+ }
+ dec.nextUint()
+ }
+ }
+ return -1
+}
+
+// Decode reads the next value from the connection and stores
+// it in the data represented by the empty interface value.
+// If e is nil, the value will be discarded. Otherwise,
+// the value underlying e must either be the correct type for the next
+// data item received, and must be a pointer.
+func (dec *Decoder) Decode(e interface{}) os.Error {
+ if e == nil {
+ return dec.DecodeValue(nil)
+ }
+ value := reflect.NewValue(e)
+ // If e represents a value as opposed to a pointer, the answer won't
+ // get back to the caller. Make sure it's a pointer.
+ if value.Type().Kind() != reflect.Ptr {
+ dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
+ return dec.err
+ }
+ return dec.DecodeValue(value)
+}
+
+// DecodeValue reads the next value from the connection and stores
+// it in the data represented by the reflection value.
+// The value must be the correct type for the next
+// data item received, or it may be nil, which means the
+// value will be discarded.
+func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
+ // Make sure we're single-threaded through here.
+ dec.mutex.Lock()
+ defer dec.mutex.Unlock()
+
+ dec.buf.Reset() // In case data lingers from previous invocation.
+ dec.err = nil
+ id := dec.decodeTypeSequence(false)
+ if dec.err == nil {
+ dec.decodeValue(id, value)
+ }
+ return dec.err
+}
+
+// If debug.go is compiled into the program , debugFunc prints a human-readable
+// representation of the gob data read from r by calling that file's Debug function.
+// Otherwise it is nil.
+var debugFunc func(io.Reader)
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "bufio"
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "sync"
+)
+
+// A Decoder manages the receipt of type and data information read from the
+// remote side of a connection.
+type Decoder struct {
+ mutex sync.Mutex // each item must be received atomically
+ r io.Reader // source of the data
+ buf bytes.Buffer // buffer for more efficient i/o from r
+ wireType map[typeId]*wireType // map from remote ID to local description
+ decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
+ ignorerCache map[typeId]**decEngine // ditto for ignored objects
+ freeList *decoderState // list of free decoderStates; avoids reallocation
+ countBuf []byte // used for decoding integers while parsing messages
+ tmp []byte // temporary storage for i/o; saves reallocating
+ err os.Error
+}
+
+// NewDecoder returns a new decoder that reads from the io.Reader.
+func NewDecoder(r io.Reader) *Decoder {
+ dec := new(Decoder)
+ dec.r = bufio.NewReader(r)
+ dec.wireType = make(map[typeId]*wireType)
+ dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
+ dec.ignorerCache = make(map[typeId]**decEngine)
+ dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
+
+ return dec
+}
+
+// recvType loads the definition of a type.
+func (dec *Decoder) recvType(id typeId) {
+ // Have we already seen this type? That's an error
+ if id < firstUserId || dec.wireType[id] != nil {
+ dec.err = os.ErrorString("gob: duplicate type received")
+ return
+ }
+
+ // Type:
+ wire := new(wireType)
+ dec.decodeValue(tWireType, reflect.NewValue(wire))
+ if dec.err != nil {
+ return
+ }
+ // Remember we've seen this type.
+ dec.wireType[id] = wire
+}
+
+// recvMessage reads the next count-delimited item from the input. It is the converse
+// of Encoder.writeMessage. It returns false on EOF or other error reading the message.
+func (dec *Decoder) recvMessage() bool {
+ // Read a count.
+ nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ return false
+ }
+ dec.readMessage(int(nbytes))
+ return dec.err == nil
+}
+
+// readMessage reads the next nbytes bytes from the input.
+func (dec *Decoder) readMessage(nbytes int) {
+ // Allocate the buffer.
+ if cap(dec.tmp) < nbytes {
+ dec.tmp = make([]byte, nbytes+100) // room to grow
+ }
+ dec.tmp = dec.tmp[:nbytes]
+
+ // Read the data
+ _, dec.err = io.ReadFull(dec.r, dec.tmp)
+ if dec.err != nil {
+ if dec.err == os.EOF {
+ dec.err = io.ErrUnexpectedEOF
+ }
+ return
+ }
+ dec.buf.Write(dec.tmp)
+}
+
+// toInt turns an encoded uint64 into an int, according to the marshaling rules.
+func toInt(x uint64) int64 {
+ i := int64(x >> 1)
+ if x&1 != 0 {
+ i = ^i
+ }
+ return i
+}
+
+func (dec *Decoder) nextInt() int64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return toInt(n)
+}
+
+func (dec *Decoder) nextUint() uint64 {
+ n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
+ if err != nil {
+ dec.err = err
+ }
+ return n
+}
+
+// decodeTypeSequence parses:
+// TypeSequence
+// (TypeDefinition DelimitedTypeDefinition*)?
+// and returns the type id of the next value. It returns -1 at
+// EOF. Upon return, the remainder of dec.buf is the value to be
+// decoded. If this is an interface value, it can be ignored by
+// simply resetting that buffer.
+func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
+ for dec.err == nil {
+ if dec.buf.Len() == 0 {
+ if !dec.recvMessage() {
+ break
+ }
+ }
+ // Receive a type id.
+ id := typeId(dec.nextInt())
+ if id >= 0 {
+ // Value follows.
+ return id
+ }
+ // Type definition for (-id) follows.
+ dec.recvType(-id)
+ // When decoding an interface, after a type there may be a
+ // DelimitedValue still in the buffer. Skip its count.
+ // (Alternatively, the buffer is empty and the byte count
+ // will be absorbed by recvMessage.)
+ if dec.buf.Len() > 0 {
+ if !isInterface {
+ dec.err = os.ErrorString("extra data in buffer")
+ break
+ }
+ dec.nextUint()
+ }
+ }
+ return -1
+}
+
+// Decode reads the next value from the connection and stores
+// it in the data represented by the empty interface value.
+// If e is nil, the value will be discarded. Otherwise,
+// the value underlying e must either be the correct type for the next
+// data item received, and must be a pointer.
+func (dec *Decoder) Decode(e interface{}) os.Error {
+ if e == nil {
+ return dec.DecodeValue(reflect.Value{})
+ }
+ value := reflect.NewValue(e)
+ // If e represents a value as opposed to a pointer, the answer won't
+ // get back to the caller. Make sure it's a pointer.
+ if value.Type().Kind() != reflect.Ptr {
+ dec.err = os.ErrorString("gob: attempt to decode into a non-pointer")
+ return dec.err
+ }
+ return dec.DecodeValue(value)
+}
+
+// DecodeValue reads the next value from the connection and stores
+// it in the data represented by the reflection value.
+// The value must be the correct type for the next
+// data item received, or it may be nil, which means the
+// value will be discarded.
+func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
+ // Make sure we're single-threaded through here.
+ dec.mutex.Lock()
+ defer dec.mutex.Unlock()
+
+ dec.buf.Reset() // In case data lingers from previous invocation.
+ dec.err = nil
+ id := dec.decodeTypeSequence(false)
+ if dec.err == nil {
+ dec.decodeValue(id, value)
+ }
+ return dec.err
+}
+
+// If debug.go is compiled into the program , debugFunc prints a human-readable
+// representation of the gob data read from r by calling that file's Debug function.
+// Otherwise it is nil.
+var debugFunc func(io.Reader)
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// DNS packet assembly. See RFC 1035.
+//
+// This is intended to support name resolution during net.Dial.
+// It doesn't have to be blazing fast.
+//
+// Rather than write the usual handful of routines to pack and
+// unpack every message that can appear on the wire, we use
+// reflection to write a generic pack/unpack for structs and then
+// use it. Thus, if in the future we need to define new message
+// structs, no new pack/unpack/printing code needs to be written.
+//
+// The first half of this file defines the DNS message formats.
+// The second half implements the conversion to and from wire format.
+// A few of the structure elements have string tags to aid the
+// generic pack/unpack routines.
+//
+// TODO(rsc): There are enough names defined in this file that they're all
+// prefixed with dns. Perhaps put this in its own package later.
+
+package net
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+)
+
+// Packet formats
+
+// Wire constants.
+const (
+ // valid dnsRR_Header.Rrtype and dnsQuestion.qtype
+ dnsTypeA = 1
+ dnsTypeNS = 2
+ dnsTypeMD = 3
+ dnsTypeMF = 4
+ dnsTypeCNAME = 5
+ dnsTypeSOA = 6
+ dnsTypeMB = 7
+ dnsTypeMG = 8
+ dnsTypeMR = 9
+ dnsTypeNULL = 10
+ dnsTypeWKS = 11
+ dnsTypePTR = 12
+ dnsTypeHINFO = 13
+ dnsTypeMINFO = 14
+ dnsTypeMX = 15
+ dnsTypeTXT = 16
+ dnsTypeAAAA = 28
+ dnsTypeSRV = 33
+
+ // valid dnsQuestion.qtype only
+ dnsTypeAXFR = 252
+ dnsTypeMAILB = 253
+ dnsTypeMAILA = 254
+ dnsTypeALL = 255
+
+ // valid dnsQuestion.qclass
+ dnsClassINET = 1
+ dnsClassCSNET = 2
+ dnsClassCHAOS = 3
+ dnsClassHESIOD = 4
+ dnsClassANY = 255
+
+ // dnsMsg.rcode
+ dnsRcodeSuccess = 0
+ dnsRcodeFormatError = 1
+ dnsRcodeServerFailure = 2
+ dnsRcodeNameError = 3
+ dnsRcodeNotImplemented = 4
+ dnsRcodeRefused = 5
+)
+
+// The wire format for the DNS packet header.
+type dnsHeader struct {
+ Id uint16
+ Bits uint16
+ Qdcount, Ancount, Nscount, Arcount uint16
+}
+
+const (
+ // dnsHeader.Bits
+ _QR = 1 << 15 // query/response (response=1)
+ _AA = 1 << 10 // authoritative
+ _TC = 1 << 9 // truncated
+ _RD = 1 << 8 // recursion desired
+ _RA = 1 << 7 // recursion available
+)
+
+// DNS queries.
+type dnsQuestion struct {
+ Name string "domain-name" // "domain-name" specifies encoding; see packers below
+ Qtype uint16
+ Qclass uint16
+}
+
+// DNS responses (resource records).
+// There are many types of messages,
+// but they all share the same header.
+type dnsRR_Header struct {
+ Name string "domain-name"
+ Rrtype uint16
+ Class uint16
+ Ttl uint32
+ Rdlength uint16 // length of data after header
+}
+
+func (h *dnsRR_Header) Header() *dnsRR_Header {
+ return h
+}
+
+type dnsRR interface {
+ Header() *dnsRR_Header
+}
+
+
+// Specific DNS RR formats for each query type.
+
+type dnsRR_CNAME struct {
+ Hdr dnsRR_Header
+ Cname string "domain-name"
+}
+
+func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_HINFO struct {
+ Hdr dnsRR_Header
+ Cpu string
+ Os string
+}
+
+func (rr *dnsRR_HINFO) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MB struct {
+ Hdr dnsRR_Header
+ Mb string "domain-name"
+}
+
+func (rr *dnsRR_MB) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MG struct {
+ Hdr dnsRR_Header
+ Mg string "domain-name"
+}
+
+func (rr *dnsRR_MG) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MINFO struct {
+ Hdr dnsRR_Header
+ Rmail string "domain-name"
+ Email string "domain-name"
+}
+
+func (rr *dnsRR_MINFO) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MR struct {
+ Hdr dnsRR_Header
+ Mr string "domain-name"
+}
+
+func (rr *dnsRR_MR) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MX struct {
+ Hdr dnsRR_Header
+ Pref uint16
+ Mx string "domain-name"
+}
+
+func (rr *dnsRR_MX) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_NS struct {
+ Hdr dnsRR_Header
+ Ns string "domain-name"
+}
+
+func (rr *dnsRR_NS) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_PTR struct {
+ Hdr dnsRR_Header
+ Ptr string "domain-name"
+}
+
+func (rr *dnsRR_PTR) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_SOA struct {
+ Hdr dnsRR_Header
+ Ns string "domain-name"
+ Mbox string "domain-name"
+ Serial uint32
+ Refresh uint32
+ Retry uint32
+ Expire uint32
+ Minttl uint32
+}
+
+func (rr *dnsRR_SOA) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_TXT struct {
+ Hdr dnsRR_Header
+ Txt string // not domain name
+}
+
+func (rr *dnsRR_TXT) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_SRV struct {
+ Hdr dnsRR_Header
+ Priority uint16
+ Weight uint16
+ Port uint16
+ Target string "domain-name"
+}
+
+func (rr *dnsRR_SRV) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_A struct {
+ Hdr dnsRR_Header
+ A uint32 "ipv4"
+}
+
+func (rr *dnsRR_A) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_AAAA struct {
+ Hdr dnsRR_Header
+ AAAA [16]byte "ipv6"
+}
+
+func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+// Packing and unpacking.
+//
+// All the packers and unpackers take a (msg []byte, off int)
+// and return (off1 int, ok bool). If they return ok==false, they
+// also return off1==len(msg), so that the next unpacker will
+// also fail. This lets us avoid checks of ok until the end of a
+// packing sequence.
+
+// Map of constructors for each RR wire type.
+var rr_mk = map[int]func() dnsRR{
+ dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) },
+ dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) },
+ dnsTypeMB: func() dnsRR { return new(dnsRR_MB) },
+ dnsTypeMG: func() dnsRR { return new(dnsRR_MG) },
+ dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) },
+ dnsTypeMR: func() dnsRR { return new(dnsRR_MR) },
+ dnsTypeMX: func() dnsRR { return new(dnsRR_MX) },
+ dnsTypeNS: func() dnsRR { return new(dnsRR_NS) },
+ dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) },
+ dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) },
+ dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) },
+ dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) },
+ dnsTypeA: func() dnsRR { return new(dnsRR_A) },
+ dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) },
+}
+
+// Pack a domain name s into msg[off:].
+// Domain names are a sequence of counted strings
+// split at the dots. They end with a zero-length string.
+func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
+ // Add trailing dot to canonicalize name.
+ if n := len(s); n == 0 || s[n-1] != '.' {
+ s += "."
+ }
+
+ // Each dot ends a segment of the name.
+ // We trade each dot byte for a length byte.
+ // There is also a trailing zero.
+ // Check that we have all the space we need.
+ tot := len(s) + 1
+ if off+tot > len(msg) {
+ return len(msg), false
+ }
+
+ // Emit sequence of counted strings, chopping at dots.
+ begin := 0
+ for i := 0; i < len(s); i++ {
+ if s[i] == '.' {
+ if i-begin >= 1<<6 { // top two bits of length must be clear
+ return len(msg), false
+ }
+ msg[off] = byte(i - begin)
+ off++
+ for j := begin; j < i; j++ {
+ msg[off] = s[j]
+ off++
+ }
+ begin = i + 1
+ }
+ }
+ msg[off] = 0
+ off++
+ return off, true
+}
+
+// Unpack a domain name.
+// In addition to the simple sequences of counted strings above,
+// domain names are allowed to refer to strings elsewhere in the
+// packet, to avoid repeating common suffixes when returning
+// many entries in a single domain. The pointers are marked
+// by a length byte with the top two bits set. Ignoring those
+// two bits, that byte and the next give a 14 bit offset from msg[0]
+// where we should pick up the trail.
+// Note that if we jump elsewhere in the packet,
+// we return off1 == the offset after the first pointer we found,
+// which is where the next record will start.
+// In theory, the pointers are only allowed to jump backward.
+// We let them jump anywhere and stop jumping after a while.
+func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
+ s = ""
+ ptr := 0 // number of pointers followed
+Loop:
+ for {
+ if off >= len(msg) {
+ return "", len(msg), false
+ }
+ c := int(msg[off])
+ off++
+ switch c & 0xC0 {
+ case 0x00:
+ if c == 0x00 {
+ // end of name
+ break Loop
+ }
+ // literal string
+ if off+c > len(msg) {
+ return "", len(msg), false
+ }
+ s += string(msg[off:off+c]) + "."
+ off += c
+ case 0xC0:
+ // pointer to somewhere else in msg.
+ // remember location after first ptr,
+ // since that's how many bytes we consumed.
+ // also, don't follow too many pointers --
+ // maybe there's a loop.
+ if off >= len(msg) {
+ return "", len(msg), false
+ }
+ c1 := msg[off]
+ off++
+ if ptr == 0 {
+ off1 = off
+ }
+ if ptr++; ptr > 10 {
+ return "", len(msg), false
+ }
+ off = (c^0xC0)<<8 | int(c1)
+ default:
+ // 0x80 and 0x40 are reserved
+ return "", len(msg), false
+ }
+ }
+ if ptr == 0 {
+ off1 = off
+ }
+ return s, off1, true
+}
+
+// TODO(rsc): Move into generic library?
+// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string,
+// [n]byte, and other (often anonymous) structs.
+func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) {
+ for i := 0; i < val.NumField(); i++ {
+ f := val.Type().(*reflect.StructType).Field(i)
+ switch fv := val.Field(i).(type) {
+ default:
+ BadType:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
+ return len(msg), false
+ case *reflect.StructValue:
+ off, ok = packStructValue(fv, msg, off)
+ case *reflect.UintValue:
+ i := fv.Get()
+ switch fv.Type().Kind() {
+ default:
+ goto BadType
+ case reflect.Uint16:
+ if off+2 > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(i >> 8)
+ msg[off+1] = byte(i)
+ off += 2
+ case reflect.Uint32:
+ if off+4 > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(i >> 24)
+ msg[off+1] = byte(i >> 16)
+ msg[off+2] = byte(i >> 8)
+ msg[off+3] = byte(i)
+ off += 4
+ }
+ case *reflect.ArrayValue:
+ if fv.Type().(*reflect.ArrayType).Elem().Kind() != reflect.Uint8 {
+ goto BadType
+ }
+ n := fv.Len()
+ if off+n > len(msg) {
+ return len(msg), false
+ }
+ reflect.Copy(reflect.NewValue(msg[off:off+n]).(*reflect.SliceValue), fv)
+ off += n
+ case *reflect.StringValue:
+ // There are multiple string encodings.
+ // The tag distinguishes ordinary strings from domain names.
+ s := fv.Get()
+ switch f.Tag {
+ default:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
+ return len(msg), false
+ case "domain-name":
+ off, ok = packDomainName(s, msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ case "":
+ // Counted string: 1 byte length.
+ if len(s) > 255 || off+1+len(s) > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(len(s))
+ off++
+ off += copy(msg[off:], s)
+ }
+ }
+ }
+ return off, true
+}
+
+func structValue(any interface{}) *reflect.StructValue {
+ return reflect.NewValue(any).(*reflect.PtrValue).Elem().(*reflect.StructValue)
+}
+
+func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
+ off, ok = packStructValue(structValue(any), msg, off)
+ return off, ok
+}
+
+// TODO(rsc): Move into generic library?
+// Unpack a reflect.StructValue from msg.
+// Same restrictions as packStructValue.
+func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) {
+ for i := 0; i < val.NumField(); i++ {
+ f := val.Type().(*reflect.StructType).Field(i)
+ switch fv := val.Field(i).(type) {
+ default:
+ BadType:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
+ return len(msg), false
+ case *reflect.StructValue:
+ off, ok = unpackStructValue(fv, msg, off)
+ case *reflect.UintValue:
+ switch fv.Type().Kind() {
+ default:
+ goto BadType
+ case reflect.Uint16:
+ if off+2 > len(msg) {
+ return len(msg), false
+ }
+ i := uint16(msg[off])<<8 | uint16(msg[off+1])
+ fv.Set(uint64(i))
+ off += 2
+ case reflect.Uint32:
+ if off+4 > len(msg) {
+ return len(msg), false
+ }
+ i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
+ fv.Set(uint64(i))
+ off += 4
+ }
+ case *reflect.ArrayValue:
+ if fv.Type().(*reflect.ArrayType).Elem().Kind() != reflect.Uint8 {
+ goto BadType
+ }
+ n := fv.Len()
+ if off+n > len(msg) {
+ return len(msg), false
+ }
+ reflect.Copy(fv, reflect.NewValue(msg[off:off+n]).(*reflect.SliceValue))
+ off += n
+ case *reflect.StringValue:
+ var s string
+ switch f.Tag {
+ default:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
+ return len(msg), false
+ case "domain-name":
+ s, off, ok = unpackDomainName(msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ case "":
+ if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
+ return len(msg), false
+ }
+ n := int(msg[off])
+ off++
+ b := make([]byte, n)
+ for i := 0; i < n; i++ {
+ b[i] = msg[off+i]
+ }
+ off += n
+ s = string(b)
+ }
+ fv.Set(s)
+ }
+ }
+ return off, true
+}
+
+func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
+ off, ok = unpackStructValue(structValue(any), msg, off)
+ return off, ok
+}
+
+// Generic struct printer.
+// Doesn't care about the string tag "domain-name",
+// but does look for an "ipv4" tag on uint32 variables
+// and the "ipv6" tag on array variables,
+// printing them as IP addresses.
+func printStructValue(val *reflect.StructValue) string {
+ s := "{"
+ for i := 0; i < val.NumField(); i++ {
+ if i > 0 {
+ s += ", "
+ }
+ f := val.Type().(*reflect.StructType).Field(i)
+ if !f.Anonymous {
+ s += f.Name + "="
+ }
+ fval := val.Field(i)
+ if fv, ok := fval.(*reflect.StructValue); ok {
+ s += printStructValue(fv)
+ } else if fv, ok := fval.(*reflect.UintValue); ok && f.Tag == "ipv4" {
+ i := fv.Get()
+ s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
+ } else if fv, ok := fval.(*reflect.ArrayValue); ok && f.Tag == "ipv6" {
+ i := fv.Interface().([]byte)
+ s += IP(i).String()
+ } else {
+ s += fmt.Sprint(fval.Interface())
+ }
+ }
+ s += "}"
+ return s
+}
+
+func printStruct(any interface{}) string { return printStructValue(structValue(any)) }
+
+// Resource record packer.
+func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
+ var off1 int
+ // pack twice, once to find end of header
+ // and again to find end of packet.
+ // a bit inefficient but this doesn't need to be fast.
+ // off1 is end of header
+ // off2 is end of rr
+ off1, ok = packStruct(rr.Header(), msg, off)
+ off2, ok = packStruct(rr, msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ // pack a third time; redo header with correct data length
+ rr.Header().Rdlength = uint16(off2 - off1)
+ packStruct(rr.Header(), msg, off)
+ return off2, true
+}
+
+// Resource record unpacker.
+func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) {
+ // unpack just the header, to find the rr type and length
+ var h dnsRR_Header
+ off0 := off
+ if off, ok = unpackStruct(&h, msg, off); !ok {
+ return nil, len(msg), false
+ }
+ end := off + int(h.Rdlength)
+
+ // make an rr of that type and re-unpack.
+ // again inefficient but doesn't need to be fast.
+ mk, known := rr_mk[int(h.Rrtype)]
+ if !known {
+ return &h, end, true
+ }
+ rr = mk()
+ off, ok = unpackStruct(rr, msg, off0)
+ if off != end {
+ return &h, end, true
+ }
+ return rr, off, ok
+}
+
+// Usable representation of a DNS packet.
+
+// A manually-unpacked version of (id, bits).
+// This is in its own struct for easy printing.
+type dnsMsgHdr struct {
+ id uint16
+ response bool
+ opcode int
+ authoritative bool
+ truncated bool
+ recursion_desired bool
+ recursion_available bool
+ rcode int
+}
+
+type dnsMsg struct {
+ dnsMsgHdr
+ question []dnsQuestion
+ answer []dnsRR
+ ns []dnsRR
+ extra []dnsRR
+}
+
+
+func (dns *dnsMsg) Pack() (msg []byte, ok bool) {
+ var dh dnsHeader
+
+ // Convert convenient dnsMsg into wire-like dnsHeader.
+ dh.Id = dns.id
+ dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode)
+ if dns.recursion_available {
+ dh.Bits |= _RA
+ }
+ if dns.recursion_desired {
+ dh.Bits |= _RD
+ }
+ if dns.truncated {
+ dh.Bits |= _TC
+ }
+ if dns.authoritative {
+ dh.Bits |= _AA
+ }
+ if dns.response {
+ dh.Bits |= _QR
+ }
+
+ // Prepare variable sized arrays.
+ question := dns.question
+ answer := dns.answer
+ ns := dns.ns
+ extra := dns.extra
+
+ dh.Qdcount = uint16(len(question))
+ dh.Ancount = uint16(len(answer))
+ dh.Nscount = uint16(len(ns))
+ dh.Arcount = uint16(len(extra))
+
+ // Could work harder to calculate message size,
+ // but this is far more than we need and not
+ // big enough to hurt the allocator.
+ msg = make([]byte, 2000)
+
+ // Pack it in: header and then the pieces.
+ off := 0
+ off, ok = packStruct(&dh, msg, off)
+ for i := 0; i < len(question); i++ {
+ off, ok = packStruct(&question[i], msg, off)
+ }
+ for i := 0; i < len(answer); i++ {
+ off, ok = packRR(answer[i], msg, off)
+ }
+ for i := 0; i < len(ns); i++ {
+ off, ok = packRR(ns[i], msg, off)
+ }
+ for i := 0; i < len(extra); i++ {
+ off, ok = packRR(extra[i], msg, off)
+ }
+ if !ok {
+ return nil, false
+ }
+ return msg[0:off], true
+}
+
+func (dns *dnsMsg) Unpack(msg []byte) bool {
+ // Header.
+ var dh dnsHeader
+ off := 0
+ var ok bool
+ if off, ok = unpackStruct(&dh, msg, off); !ok {
+ return false
+ }
+ dns.id = dh.Id
+ dns.response = (dh.Bits & _QR) != 0
+ dns.opcode = int(dh.Bits>>11) & 0xF
+ dns.authoritative = (dh.Bits & _AA) != 0
+ dns.truncated = (dh.Bits & _TC) != 0
+ dns.recursion_desired = (dh.Bits & _RD) != 0
+ dns.recursion_available = (dh.Bits & _RA) != 0
+ dns.rcode = int(dh.Bits & 0xF)
+
+ // Arrays.
+ dns.question = make([]dnsQuestion, dh.Qdcount)
+ dns.answer = make([]dnsRR, dh.Ancount)
+ dns.ns = make([]dnsRR, dh.Nscount)
+ dns.extra = make([]dnsRR, dh.Arcount)
+
+ for i := 0; i < len(dns.question); i++ {
+ off, ok = unpackStruct(&dns.question[i], msg, off)
+ }
+ for i := 0; i < len(dns.answer); i++ {
+ dns.answer[i], off, ok = unpackRR(msg, off)
+ }
+ for i := 0; i < len(dns.ns); i++ {
+ dns.ns[i], off, ok = unpackRR(msg, off)
+ }
+ for i := 0; i < len(dns.extra); i++ {
+ dns.extra[i], off, ok = unpackRR(msg, off)
+ }
+ if !ok {
+ return false
+ }
+ // if off != len(msg) {
+ // println("extra bytes in dns packet", off, "<", len(msg));
+ // }
+ return true
+}
+
+func (dns *dnsMsg) String() string {
+ s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n"
+ if len(dns.question) > 0 {
+ s += "-- Questions\n"
+ for i := 0; i < len(dns.question); i++ {
+ s += printStruct(&dns.question[i]) + "\n"
+ }
+ }
+ if len(dns.answer) > 0 {
+ s += "-- Answers\n"
+ for i := 0; i < len(dns.answer); i++ {
+ s += printStruct(dns.answer[i]) + "\n"
+ }
+ }
+ if len(dns.ns) > 0 {
+ s += "-- Name servers\n"
+ for i := 0; i < len(dns.ns); i++ {
+ s += printStruct(dns.ns[i]) + "\n"
+ }
+ }
+ if len(dns.extra) > 0 {
+ s += "-- Extra\n"
+ for i := 0; i < len(dns.extra); i++ {
+ s += printStruct(dns.extra[i]) + "\n"
+ }
+ }
+ return s
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// DNS packet assembly. See RFC 1035.
+//
+// This is intended to support name resolution during net.Dial.
+// It doesn't have to be blazing fast.
+//
+// Rather than write the usual handful of routines to pack and
+// unpack every message that can appear on the wire, we use
+// reflection to write a generic pack/unpack for structs and then
+// use it. Thus, if in the future we need to define new message
+// structs, no new pack/unpack/printing code needs to be written.
+//
+// The first half of this file defines the DNS message formats.
+// The second half implements the conversion to and from wire format.
+// A few of the structure elements have string tags to aid the
+// generic pack/unpack routines.
+//
+// TODO(rsc): There are enough names defined in this file that they're all
+// prefixed with dns. Perhaps put this in its own package later.
+
+package net
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+)
+
+// Packet formats
+
+// Wire constants.
+const (
+ // valid dnsRR_Header.Rrtype and dnsQuestion.qtype
+ dnsTypeA = 1
+ dnsTypeNS = 2
+ dnsTypeMD = 3
+ dnsTypeMF = 4
+ dnsTypeCNAME = 5
+ dnsTypeSOA = 6
+ dnsTypeMB = 7
+ dnsTypeMG = 8
+ dnsTypeMR = 9
+ dnsTypeNULL = 10
+ dnsTypeWKS = 11
+ dnsTypePTR = 12
+ dnsTypeHINFO = 13
+ dnsTypeMINFO = 14
+ dnsTypeMX = 15
+ dnsTypeTXT = 16
+ dnsTypeAAAA = 28
+ dnsTypeSRV = 33
+
+ // valid dnsQuestion.qtype only
+ dnsTypeAXFR = 252
+ dnsTypeMAILB = 253
+ dnsTypeMAILA = 254
+ dnsTypeALL = 255
+
+ // valid dnsQuestion.qclass
+ dnsClassINET = 1
+ dnsClassCSNET = 2
+ dnsClassCHAOS = 3
+ dnsClassHESIOD = 4
+ dnsClassANY = 255
+
+ // dnsMsg.rcode
+ dnsRcodeSuccess = 0
+ dnsRcodeFormatError = 1
+ dnsRcodeServerFailure = 2
+ dnsRcodeNameError = 3
+ dnsRcodeNotImplemented = 4
+ dnsRcodeRefused = 5
+)
+
+// The wire format for the DNS packet header.
+type dnsHeader struct {
+ Id uint16
+ Bits uint16
+ Qdcount, Ancount, Nscount, Arcount uint16
+}
+
+const (
+ // dnsHeader.Bits
+ _QR = 1 << 15 // query/response (response=1)
+ _AA = 1 << 10 // authoritative
+ _TC = 1 << 9 // truncated
+ _RD = 1 << 8 // recursion desired
+ _RA = 1 << 7 // recursion available
+)
+
+// DNS queries.
+type dnsQuestion struct {
+ Name string "domain-name" // "domain-name" specifies encoding; see packers below
+ Qtype uint16
+ Qclass uint16
+}
+
+// DNS responses (resource records).
+// There are many types of messages,
+// but they all share the same header.
+type dnsRR_Header struct {
+ Name string "domain-name"
+ Rrtype uint16
+ Class uint16
+ Ttl uint32
+ Rdlength uint16 // length of data after header
+}
+
+func (h *dnsRR_Header) Header() *dnsRR_Header {
+ return h
+}
+
+type dnsRR interface {
+ Header() *dnsRR_Header
+}
+
+
+// Specific DNS RR formats for each query type.
+
+type dnsRR_CNAME struct {
+ Hdr dnsRR_Header
+ Cname string "domain-name"
+}
+
+func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_HINFO struct {
+ Hdr dnsRR_Header
+ Cpu string
+ Os string
+}
+
+func (rr *dnsRR_HINFO) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MB struct {
+ Hdr dnsRR_Header
+ Mb string "domain-name"
+}
+
+func (rr *dnsRR_MB) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MG struct {
+ Hdr dnsRR_Header
+ Mg string "domain-name"
+}
+
+func (rr *dnsRR_MG) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MINFO struct {
+ Hdr dnsRR_Header
+ Rmail string "domain-name"
+ Email string "domain-name"
+}
+
+func (rr *dnsRR_MINFO) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MR struct {
+ Hdr dnsRR_Header
+ Mr string "domain-name"
+}
+
+func (rr *dnsRR_MR) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_MX struct {
+ Hdr dnsRR_Header
+ Pref uint16
+ Mx string "domain-name"
+}
+
+func (rr *dnsRR_MX) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_NS struct {
+ Hdr dnsRR_Header
+ Ns string "domain-name"
+}
+
+func (rr *dnsRR_NS) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_PTR struct {
+ Hdr dnsRR_Header
+ Ptr string "domain-name"
+}
+
+func (rr *dnsRR_PTR) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_SOA struct {
+ Hdr dnsRR_Header
+ Ns string "domain-name"
+ Mbox string "domain-name"
+ Serial uint32
+ Refresh uint32
+ Retry uint32
+ Expire uint32
+ Minttl uint32
+}
+
+func (rr *dnsRR_SOA) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_TXT struct {
+ Hdr dnsRR_Header
+ Txt string // not domain name
+}
+
+func (rr *dnsRR_TXT) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_SRV struct {
+ Hdr dnsRR_Header
+ Priority uint16
+ Weight uint16
+ Port uint16
+ Target string "domain-name"
+}
+
+func (rr *dnsRR_SRV) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_A struct {
+ Hdr dnsRR_Header
+ A uint32 "ipv4"
+}
+
+func (rr *dnsRR_A) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+type dnsRR_AAAA struct {
+ Hdr dnsRR_Header
+ AAAA [16]byte "ipv6"
+}
+
+func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
+ return &rr.Hdr
+}
+
+// Packing and unpacking.
+//
+// All the packers and unpackers take a (msg []byte, off int)
+// and return (off1 int, ok bool). If they return ok==false, they
+// also return off1==len(msg), so that the next unpacker will
+// also fail. This lets us avoid checks of ok until the end of a
+// packing sequence.
+
+// Map of constructors for each RR wire type.
+var rr_mk = map[int]func() dnsRR{
+ dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) },
+ dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) },
+ dnsTypeMB: func() dnsRR { return new(dnsRR_MB) },
+ dnsTypeMG: func() dnsRR { return new(dnsRR_MG) },
+ dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) },
+ dnsTypeMR: func() dnsRR { return new(dnsRR_MR) },
+ dnsTypeMX: func() dnsRR { return new(dnsRR_MX) },
+ dnsTypeNS: func() dnsRR { return new(dnsRR_NS) },
+ dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) },
+ dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) },
+ dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) },
+ dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) },
+ dnsTypeA: func() dnsRR { return new(dnsRR_A) },
+ dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) },
+}
+
+// Pack a domain name s into msg[off:].
+// Domain names are a sequence of counted strings
+// split at the dots. They end with a zero-length string.
+func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
+ // Add trailing dot to canonicalize name.
+ if n := len(s); n == 0 || s[n-1] != '.' {
+ s += "."
+ }
+
+ // Each dot ends a segment of the name.
+ // We trade each dot byte for a length byte.
+ // There is also a trailing zero.
+ // Check that we have all the space we need.
+ tot := len(s) + 1
+ if off+tot > len(msg) {
+ return len(msg), false
+ }
+
+ // Emit sequence of counted strings, chopping at dots.
+ begin := 0
+ for i := 0; i < len(s); i++ {
+ if s[i] == '.' {
+ if i-begin >= 1<<6 { // top two bits of length must be clear
+ return len(msg), false
+ }
+ msg[off] = byte(i - begin)
+ off++
+ for j := begin; j < i; j++ {
+ msg[off] = s[j]
+ off++
+ }
+ begin = i + 1
+ }
+ }
+ msg[off] = 0
+ off++
+ return off, true
+}
+
+// Unpack a domain name.
+// In addition to the simple sequences of counted strings above,
+// domain names are allowed to refer to strings elsewhere in the
+// packet, to avoid repeating common suffixes when returning
+// many entries in a single domain. The pointers are marked
+// by a length byte with the top two bits set. Ignoring those
+// two bits, that byte and the next give a 14 bit offset from msg[0]
+// where we should pick up the trail.
+// Note that if we jump elsewhere in the packet,
+// we return off1 == the offset after the first pointer we found,
+// which is where the next record will start.
+// In theory, the pointers are only allowed to jump backward.
+// We let them jump anywhere and stop jumping after a while.
+func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
+ s = ""
+ ptr := 0 // number of pointers followed
+Loop:
+ for {
+ if off >= len(msg) {
+ return "", len(msg), false
+ }
+ c := int(msg[off])
+ off++
+ switch c & 0xC0 {
+ case 0x00:
+ if c == 0x00 {
+ // end of name
+ break Loop
+ }
+ // literal string
+ if off+c > len(msg) {
+ return "", len(msg), false
+ }
+ s += string(msg[off:off+c]) + "."
+ off += c
+ case 0xC0:
+ // pointer to somewhere else in msg.
+ // remember location after first ptr,
+ // since that's how many bytes we consumed.
+ // also, don't follow too many pointers --
+ // maybe there's a loop.
+ if off >= len(msg) {
+ return "", len(msg), false
+ }
+ c1 := msg[off]
+ off++
+ if ptr == 0 {
+ off1 = off
+ }
+ if ptr++; ptr > 10 {
+ return "", len(msg), false
+ }
+ off = (c^0xC0)<<8 | int(c1)
+ default:
+ // 0x80 and 0x40 are reserved
+ return "", len(msg), false
+ }
+ }
+ if ptr == 0 {
+ off1 = off
+ }
+ return s, off1, true
+}
+
+// TODO(rsc): Move into generic library?
+// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string,
+// [n]byte, and other (often anonymous) structs.
+func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
+ for i := 0; i < val.NumField(); i++ {
+ f := val.Type().Field(i)
+ switch fv := val.Field(i); fv.Kind() {
+ default:
+ BadType:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
+ return len(msg), false
+ case reflect.Struct:
+ off, ok = packStructValue(fv, msg, off)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ i := fv.Uint()
+ switch fv.Type().Kind() {
+ default:
+ goto BadType
+ case reflect.Uint16:
+ if off+2 > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(i >> 8)
+ msg[off+1] = byte(i)
+ off += 2
+ case reflect.Uint32:
+ if off+4 > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(i >> 24)
+ msg[off+1] = byte(i >> 16)
+ msg[off+2] = byte(i >> 8)
+ msg[off+3] = byte(i)
+ off += 4
+ }
+ case reflect.Array:
+ if fv.Type().Elem().Kind() != reflect.Uint8 {
+ goto BadType
+ }
+ n := fv.Len()
+ if off+n > len(msg) {
+ return len(msg), false
+ }
+ reflect.Copy(reflect.NewValue(msg[off:off+n]), fv)
+ off += n
+ case reflect.String:
+ // There are multiple string encodings.
+ // The tag distinguishes ordinary strings from domain names.
+ s := fv.String()
+ switch f.Tag {
+ default:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
+ return len(msg), false
+ case "domain-name":
+ off, ok = packDomainName(s, msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ case "":
+ // Counted string: 1 byte length.
+ if len(s) > 255 || off+1+len(s) > len(msg) {
+ return len(msg), false
+ }
+ msg[off] = byte(len(s))
+ off++
+ off += copy(msg[off:], s)
+ }
+ }
+ }
+ return off, true
+}
+
+func structValue(any interface{}) reflect.Value {
+ return reflect.NewValue(any).Elem()
+}
+
+func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
+ off, ok = packStructValue(structValue(any), msg, off)
+ return off, ok
+}
+
+// TODO(rsc): Move into generic library?
+// Unpack a reflect.StructValue from msg.
+// Same restrictions as packStructValue.
+func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
+ for i := 0; i < val.NumField(); i++ {
+ f := val.Type().Field(i)
+ switch fv := val.Field(i); fv.Kind() {
+ default:
+ BadType:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
+ return len(msg), false
+ case reflect.Struct:
+ off, ok = unpackStructValue(fv, msg, off)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ switch fv.Type().Kind() {
+ default:
+ goto BadType
+ case reflect.Uint16:
+ if off+2 > len(msg) {
+ return len(msg), false
+ }
+ i := uint16(msg[off])<<8 | uint16(msg[off+1])
+ fv.SetUint(uint64(i))
+ off += 2
+ case reflect.Uint32:
+ if off+4 > len(msg) {
+ return len(msg), false
+ }
+ i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
+ fv.SetUint(uint64(i))
+ off += 4
+ }
+ case reflect.Array:
+ if fv.Type().Elem().Kind() != reflect.Uint8 {
+ goto BadType
+ }
+ n := fv.Len()
+ if off+n > len(msg) {
+ return len(msg), false
+ }
+ reflect.Copy(fv, reflect.NewValue(msg[off:off+n]))
+ off += n
+ case reflect.String:
+ var s string
+ switch f.Tag {
+ default:
+ fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
+ return len(msg), false
+ case "domain-name":
+ s, off, ok = unpackDomainName(msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ case "":
+ if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
+ return len(msg), false
+ }
+ n := int(msg[off])
+ off++
+ b := make([]byte, n)
+ for i := 0; i < n; i++ {
+ b[i] = msg[off+i]
+ }
+ off += n
+ s = string(b)
+ }
+ fv.SetString(s)
+ }
+ }
+ return off, true
+}
+
+func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
+ off, ok = unpackStructValue(structValue(any), msg, off)
+ return off, ok
+}
+
+// Generic struct printer.
+// Doesn't care about the string tag "domain-name",
+// but does look for an "ipv4" tag on uint32 variables
+// and the "ipv6" tag on array variables,
+// printing them as IP addresses.
+func printStructValue(val reflect.Value) string {
+ s := "{"
+ for i := 0; i < val.NumField(); i++ {
+ if i > 0 {
+ s += ", "
+ }
+ f := val.Type().Field(i)
+ if !f.Anonymous {
+ s += f.Name + "="
+ }
+ fval := val.Field(i)
+ if fv := fval; fv.Kind() == reflect.Struct {
+ s += printStructValue(fv)
+ } else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == "ipv4" {
+ i := fv.Uint()
+ s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
+ } else if fv := fval; fv.Kind() == reflect.Array && f.Tag == "ipv6" {
+ i := fv.Interface().([]byte)
+ s += IP(i).String()
+ } else {
+ s += fmt.Sprint(fval.Interface())
+ }
+ }
+ s += "}"
+ return s
+}
+
+func printStruct(any interface{}) string { return printStructValue(structValue(any)) }
+
+// Resource record packer.
+func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
+ var off1 int
+ // pack twice, once to find end of header
+ // and again to find end of packet.
+ // a bit inefficient but this doesn't need to be fast.
+ // off1 is end of header
+ // off2 is end of rr
+ off1, ok = packStruct(rr.Header(), msg, off)
+ off2, ok = packStruct(rr, msg, off)
+ if !ok {
+ return len(msg), false
+ }
+ // pack a third time; redo header with correct data length
+ rr.Header().Rdlength = uint16(off2 - off1)
+ packStruct(rr.Header(), msg, off)
+ return off2, true
+}
+
+// Resource record unpacker.
+func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) {
+ // unpack just the header, to find the rr type and length
+ var h dnsRR_Header
+ off0 := off
+ if off, ok = unpackStruct(&h, msg, off); !ok {
+ return nil, len(msg), false
+ }
+ end := off + int(h.Rdlength)
+
+ // make an rr of that type and re-unpack.
+ // again inefficient but doesn't need to be fast.
+ mk, known := rr_mk[int(h.Rrtype)]
+ if !known {
+ return &h, end, true
+ }
+ rr = mk()
+ off, ok = unpackStruct(rr, msg, off0)
+ if off != end {
+ return &h, end, true
+ }
+ return rr, off, ok
+}
+
+// Usable representation of a DNS packet.
+
+// A manually-unpacked version of (id, bits).
+// This is in its own struct for easy printing.
+type dnsMsgHdr struct {
+ id uint16
+ response bool
+ opcode int
+ authoritative bool
+ truncated bool
+ recursion_desired bool
+ recursion_available bool
+ rcode int
+}
+
+type dnsMsg struct {
+ dnsMsgHdr
+ question []dnsQuestion
+ answer []dnsRR
+ ns []dnsRR
+ extra []dnsRR
+}
+
+
+func (dns *dnsMsg) Pack() (msg []byte, ok bool) {
+ var dh dnsHeader
+
+ // Convert convenient dnsMsg into wire-like dnsHeader.
+ dh.Id = dns.id
+ dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode)
+ if dns.recursion_available {
+ dh.Bits |= _RA
+ }
+ if dns.recursion_desired {
+ dh.Bits |= _RD
+ }
+ if dns.truncated {
+ dh.Bits |= _TC
+ }
+ if dns.authoritative {
+ dh.Bits |= _AA
+ }
+ if dns.response {
+ dh.Bits |= _QR
+ }
+
+ // Prepare variable sized arrays.
+ question := dns.question
+ answer := dns.answer
+ ns := dns.ns
+ extra := dns.extra
+
+ dh.Qdcount = uint16(len(question))
+ dh.Ancount = uint16(len(answer))
+ dh.Nscount = uint16(len(ns))
+ dh.Arcount = uint16(len(extra))
+
+ // Could work harder to calculate message size,
+ // but this is far more than we need and not
+ // big enough to hurt the allocator.
+ msg = make([]byte, 2000)
+
+ // Pack it in: header and then the pieces.
+ off := 0
+ off, ok = packStruct(&dh, msg, off)
+ for i := 0; i < len(question); i++ {
+ off, ok = packStruct(&question[i], msg, off)
+ }
+ for i := 0; i < len(answer); i++ {
+ off, ok = packRR(answer[i], msg, off)
+ }
+ for i := 0; i < len(ns); i++ {
+ off, ok = packRR(ns[i], msg, off)
+ }
+ for i := 0; i < len(extra); i++ {
+ off, ok = packRR(extra[i], msg, off)
+ }
+ if !ok {
+ return nil, false
+ }
+ return msg[0:off], true
+}
+
+func (dns *dnsMsg) Unpack(msg []byte) bool {
+ // Header.
+ var dh dnsHeader
+ off := 0
+ var ok bool
+ if off, ok = unpackStruct(&dh, msg, off); !ok {
+ return false
+ }
+ dns.id = dh.Id
+ dns.response = (dh.Bits & _QR) != 0
+ dns.opcode = int(dh.Bits>>11) & 0xF
+ dns.authoritative = (dh.Bits & _AA) != 0
+ dns.truncated = (dh.Bits & _TC) != 0
+ dns.recursion_desired = (dh.Bits & _RD) != 0
+ dns.recursion_available = (dh.Bits & _RA) != 0
+ dns.rcode = int(dh.Bits & 0xF)
+
+ // Arrays.
+ dns.question = make([]dnsQuestion, dh.Qdcount)
+ dns.answer = make([]dnsRR, dh.Ancount)
+ dns.ns = make([]dnsRR, dh.Nscount)
+ dns.extra = make([]dnsRR, dh.Arcount)
+
+ for i := 0; i < len(dns.question); i++ {
+ off, ok = unpackStruct(&dns.question[i], msg, off)
+ }
+ for i := 0; i < len(dns.answer); i++ {
+ dns.answer[i], off, ok = unpackRR(msg, off)
+ }
+ for i := 0; i < len(dns.ns); i++ {
+ dns.ns[i], off, ok = unpackRR(msg, off)
+ }
+ for i := 0; i < len(dns.extra); i++ {
+ dns.extra[i], off, ok = unpackRR(msg, off)
+ }
+ if !ok {
+ return false
+ }
+ // if off != len(msg) {
+ // println("extra bytes in dns packet", off, "<", len(msg));
+ // }
+ return true
+}
+
+func (dns *dnsMsg) String() string {
+ s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n"
+ if len(dns.question) > 0 {
+ s += "-- Questions\n"
+ for i := 0; i < len(dns.question); i++ {
+ s += printStruct(&dns.question[i]) + "\n"
+ }
+ }
+ if len(dns.answer) > 0 {
+ s += "-- Answers\n"
+ for i := 0; i < len(dns.answer); i++ {
+ s += printStruct(dns.answer[i]) + "\n"
+ }
+ }
+ if len(dns.ns) > 0 {
+ s += "-- Name servers\n"
+ for i := 0; i < len(dns.ns); i++ {
+ s += printStruct(dns.ns[i]) + "\n"
+ }
+ }
+ if len(dns.extra) > 0 {
+ s += "-- Extra\n"
+ for i := 0; i < len(dns.extra); i++ {
+ s += printStruct(dns.extra[i]) + "\n"
+ }
+ }
+ return s
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The json package implements encoding and decoding of JSON objects as
+// defined in RFC 4627.
+package json
+
+import (
+ "bytes"
+ "encoding/base64"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strconv"
+ "unicode"
+ "utf8"
+)
+
+// Marshal returns the JSON encoding of v.
+//
+// Marshal traverses the value v recursively.
+// If an encountered value implements the Marshaler interface,
+// Marshal calls its MarshalJSON method to produce JSON.
+//
+// Otherwise, Marshal uses the following type-dependent default encodings:
+//
+// Boolean values encode as JSON booleans.
+//
+// Floating point and integer values encode as JSON numbers.
+//
+// String values encode as JSON strings, with each invalid UTF-8 sequence
+// replaced by the encoding of the Unicode replacement character U+FFFD.
+//
+// Array and slice values encode as JSON arrays, except that
+// []byte encodes as a base64-encoded string.
+//
+// Struct values encode as JSON objects. Each struct field becomes
+// a member of the object. By default the object's key name is the
+// struct field name. If the struct field has a non-empty tag consisting
+// of only Unicode letters, digits, and underscores, that tag will be used
+// as the name instead. Only exported fields will be encoded.
+//
+// Map values encode as JSON objects.
+// The map's key type must be string; the object keys are used directly
+// as map keys.
+//
+// Pointer values encode as the value pointed to.
+// A nil pointer encodes as the null JSON object.
+//
+// Interface values encode as the value contained in the interface.
+// A nil interface value encodes as the null JSON object.
+//
+// Channel, complex, and function values cannot be encoded in JSON.
+// Attempting to encode such a value causes Marshal to return
+// an InvalidTypeError.
+//
+// JSON cannot represent cyclic data structures and Marshal does not
+// handle them. Passing cyclic structures to Marshal will result in
+// an infinite recursion.
+//
+func Marshal(v interface{}) ([]byte, os.Error) {
+ e := &encodeState{}
+ err := e.marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ return e.Bytes(), nil
+}
+
+// MarshalIndent is like Marshal but applies Indent to format the output.
+func MarshalIndent(v interface{}, prefix, indent string) ([]byte, os.Error) {
+ b, err := Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ var buf bytes.Buffer
+ err = Indent(&buf, b, prefix, indent)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+// MarshalForHTML is like Marshal but applies HTMLEscape to the output.
+func MarshalForHTML(v interface{}) ([]byte, os.Error) {
+ b, err := Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ var buf bytes.Buffer
+ HTMLEscape(&buf, b)
+ return buf.Bytes(), nil
+}
+
+// HTMLEscape appends to dst the JSON-encoded src with <, >, and &
+// characters inside string literals changed to \u003c, \u003e, \u0026
+// so that the JSON will be safe to embed inside HTML <script> tags.
+// For historical reasons, web browsers don't honor standard HTML
+// escaping within <script> tags, so an alternative JSON encoding must
+// be used.
+func HTMLEscape(dst *bytes.Buffer, src []byte) {
+ // < > & can only appear in string literals,
+ // so just scan the string one byte at a time.
+ start := 0
+ for i, c := range src {
+ if c == '<' || c == '>' || c == '&' {
+ if start < i {
+ dst.Write(src[start:i])
+ }
+ dst.WriteString(`\u00`)
+ dst.WriteByte(hex[c>>4])
+ dst.WriteByte(hex[c&0xF])
+ start = i + 1
+ }
+ }
+ if start < len(src) {
+ dst.Write(src[start:])
+ }
+}
+
+// Marshaler is the interface implemented by objects that
+// can marshal themselves into valid JSON.
+type Marshaler interface {
+ MarshalJSON() ([]byte, os.Error)
+}
+
+type UnsupportedTypeError struct {
+ Type reflect.Type
+}
+
+func (e *UnsupportedTypeError) String() string {
+ return "json: unsupported type: " + e.Type.String()
+}
+
+type InvalidUTF8Error struct {
+ S string
+}
+
+func (e *InvalidUTF8Error) String() string {
+ return "json: invalid UTF-8 in string: " + strconv.Quote(e.S)
+}
+
+type MarshalerError struct {
+ Type reflect.Type
+ Error os.Error
+}
+
+func (e *MarshalerError) String() string {
+ return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Error.String()
+}
+
+type interfaceOrPtrValue interface {
+ IsNil() bool
+ Elem() reflect.Value
+}
+
+var hex = "0123456789abcdef"
+
+// An encodeState encodes JSON into a bytes.Buffer.
+type encodeState struct {
+ bytes.Buffer // accumulated output
+}
+
+func (e *encodeState) marshal(v interface{}) (err os.Error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ err = r.(os.Error)
+ }
+ }()
+ e.reflectValue(reflect.NewValue(v))
+ return nil
+}
+
+func (e *encodeState) error(err os.Error) {
+ panic(err)
+}
+
+var byteSliceType = reflect.Typeof([]byte(nil))
+
+func (e *encodeState) reflectValue(v reflect.Value) {
+ if v == nil {
+ e.WriteString("null")
+ return
+ }
+
+ if j, ok := v.Interface().(Marshaler); ok {
+ b, err := j.MarshalJSON()
+ if err == nil {
+ // copy JSON into buffer, checking validity.
+ err = Compact(&e.Buffer, b)
+ }
+ if err != nil {
+ e.error(&MarshalerError{v.Type(), err})
+ }
+ return
+ }
+
+ switch v := v.(type) {
+ case *reflect.BoolValue:
+ x := v.Get()
+ if x {
+ e.WriteString("true")
+ } else {
+ e.WriteString("false")
+ }
+
+ case *reflect.IntValue:
+ e.WriteString(strconv.Itoa64(v.Get()))
+
+ case *reflect.UintValue:
+ e.WriteString(strconv.Uitoa64(v.Get()))
+
+ case *reflect.FloatValue:
+ e.WriteString(strconv.FtoaN(v.Get(), 'g', -1, v.Type().Bits()))
+
+ case *reflect.StringValue:
+ e.string(v.Get())
+
+ case *reflect.StructValue:
+ e.WriteByte('{')
+ t := v.Type().(*reflect.StructType)
+ n := v.NumField()
+ first := true
+ for i := 0; i < n; i++ {
+ f := t.Field(i)
+ if f.PkgPath != "" {
+ continue
+ }
+ if first {
+ first = false
+ } else {
+ e.WriteByte(',')
+ }
+ if isValidTag(f.Tag) {
+ e.string(f.Tag)
+ } else {
+ e.string(f.Name)
+ }
+ e.WriteByte(':')
+ e.reflectValue(v.Field(i))
+ }
+ e.WriteByte('}')
+
+ case *reflect.MapValue:
+ if _, ok := v.Type().(*reflect.MapType).Key().(*reflect.StringType); !ok {
+ e.error(&UnsupportedTypeError{v.Type()})
+ }
+ if v.IsNil() {
+ e.WriteString("null")
+ break
+ }
+ e.WriteByte('{')
+ var sv stringValues = v.Keys()
+ sort.Sort(sv)
+ for i, k := range sv {
+ if i > 0 {
+ e.WriteByte(',')
+ }
+ e.string(k.(*reflect.StringValue).Get())
+ e.WriteByte(':')
+ e.reflectValue(v.Elem(k))
+ }
+ e.WriteByte('}')
+
+ case reflect.ArrayOrSliceValue:
+ if v.Type() == byteSliceType {
+ e.WriteByte('"')
+ s := v.Interface().([]byte)
+ if len(s) < 1024 {
+ // for small buffers, using Encode directly is much faster.
+ dst := make([]byte, base64.StdEncoding.EncodedLen(len(s)))
+ base64.StdEncoding.Encode(dst, s)
+ e.Write(dst)
+ } else {
+ // for large buffers, avoid unnecessary extra temporary
+ // buffer space.
+ enc := base64.NewEncoder(base64.StdEncoding, e)
+ enc.Write(s)
+ enc.Close()
+ }
+ e.WriteByte('"')
+ break
+ }
+ e.WriteByte('[')
+ n := v.Len()
+ for i := 0; i < n; i++ {
+ if i > 0 {
+ e.WriteByte(',')
+ }
+ e.reflectValue(v.Elem(i))
+ }
+ e.WriteByte(']')
+
+ case interfaceOrPtrValue:
+ if v.IsNil() {
+ e.WriteString("null")
+ return
+ }
+ e.reflectValue(v.Elem())
+
+ default:
+ e.error(&UnsupportedTypeError{v.Type()})
+ }
+ return
+}
+
+func isValidTag(s string) bool {
+ if s == "" {
+ return false
+ }
+ for _, c := range s {
+ if c != '_' && !unicode.IsLetter(c) && !unicode.IsDigit(c) {
+ return false
+ }
+ }
+ return true
+}
+
+// stringValues is a slice of reflect.Value holding *reflect.StringValue.
+// It implements the methods to sort by string.
+type stringValues []reflect.Value
+
+func (sv stringValues) Len() int { return len(sv) }
+func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
+func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
+func (sv stringValues) get(i int) string { return sv[i].(*reflect.StringValue).Get() }
+
+func (e *encodeState) string(s string) {
+ e.WriteByte('"')
+ start := 0
+ for i := 0; i < len(s); {
+ if b := s[i]; b < utf8.RuneSelf {
+ if 0x20 <= b && b != '\\' && b != '"' {
+ i++
+ continue
+ }
+ if start < i {
+ e.WriteString(s[start:i])
+ }
+ if b == '\\' || b == '"' {
+ e.WriteByte('\\')
+ e.WriteByte(b)
+ } else {
+ e.WriteString(`\u00`)
+ e.WriteByte(hex[b>>4])
+ e.WriteByte(hex[b&0xF])
+ }
+ i++
+ start = i
+ continue
+ }
+ c, size := utf8.DecodeRuneInString(s[i:])
+ if c == utf8.RuneError && size == 1 {
+ e.error(&InvalidUTF8Error{s})
+ }
+ i += size
+ }
+ if start < len(s) {
+ e.WriteString(s[start:])
+ }
+ e.WriteByte('"')
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The json package implements encoding and decoding of JSON objects as
+// defined in RFC 4627.
+package json
+
+import (
+ "bytes"
+ "encoding/base64"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strconv"
+ "unicode"
+ "utf8"
+)
+
+// Marshal returns the JSON encoding of v.
+//
+// Marshal traverses the value v recursively.
+// If an encountered value implements the Marshaler interface,
+// Marshal calls its MarshalJSON method to produce JSON.
+//
+// Otherwise, Marshal uses the following type-dependent default encodings:
+//
+// Boolean values encode as JSON booleans.
+//
+// Floating point and integer values encode as JSON numbers.
+//
+// String values encode as JSON strings, with each invalid UTF-8 sequence
+// replaced by the encoding of the Unicode replacement character U+FFFD.
+//
+// Array and slice values encode as JSON arrays, except that
+// []byte encodes as a base64-encoded string.
+//
+// Struct values encode as JSON objects. Each struct field becomes
+// a member of the object. By default the object's key name is the
+// struct field name. If the struct field has a non-empty tag consisting
+// of only Unicode letters, digits, and underscores, that tag will be used
+// as the name instead. Only exported fields will be encoded.
+//
+// Map values encode as JSON objects.
+// The map's key type must be string; the object keys are used directly
+// as map keys.
+//
+// Pointer values encode as the value pointed to.
+// A nil pointer encodes as the null JSON object.
+//
+// Interface values encode as the value contained in the interface.
+// A nil interface value encodes as the null JSON object.
+//
+// Channel, complex, and function values cannot be encoded in JSON.
+// Attempting to encode such a value causes Marshal to return
+// an InvalidTypeError.
+//
+// JSON cannot represent cyclic data structures and Marshal does not
+// handle them. Passing cyclic structures to Marshal will result in
+// an infinite recursion.
+//
+func Marshal(v interface{}) ([]byte, os.Error) {
+ e := &encodeState{}
+ err := e.marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ return e.Bytes(), nil
+}
+
+// MarshalIndent is like Marshal but applies Indent to format the output.
+func MarshalIndent(v interface{}, prefix, indent string) ([]byte, os.Error) {
+ b, err := Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ var buf bytes.Buffer
+ err = Indent(&buf, b, prefix, indent)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+// MarshalForHTML is like Marshal but applies HTMLEscape to the output.
+func MarshalForHTML(v interface{}) ([]byte, os.Error) {
+ b, err := Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ var buf bytes.Buffer
+ HTMLEscape(&buf, b)
+ return buf.Bytes(), nil
+}
+
+// HTMLEscape appends to dst the JSON-encoded src with <, >, and &
+// characters inside string literals changed to \u003c, \u003e, \u0026
+// so that the JSON will be safe to embed inside HTML <script> tags.
+// For historical reasons, web browsers don't honor standard HTML
+// escaping within <script> tags, so an alternative JSON encoding must
+// be used.
+func HTMLEscape(dst *bytes.Buffer, src []byte) {
+ // < > & can only appear in string literals,
+ // so just scan the string one byte at a time.
+ start := 0
+ for i, c := range src {
+ if c == '<' || c == '>' || c == '&' {
+ if start < i {
+ dst.Write(src[start:i])
+ }
+ dst.WriteString(`\u00`)
+ dst.WriteByte(hex[c>>4])
+ dst.WriteByte(hex[c&0xF])
+ start = i + 1
+ }
+ }
+ if start < len(src) {
+ dst.Write(src[start:])
+ }
+}
+
+// Marshaler is the interface implemented by objects that
+// can marshal themselves into valid JSON.
+type Marshaler interface {
+ MarshalJSON() ([]byte, os.Error)
+}
+
+type UnsupportedTypeError struct {
+ Type reflect.Type
+}
+
+func (e *UnsupportedTypeError) String() string {
+ return "json: unsupported type: " + e.Type.String()
+}
+
+type InvalidUTF8Error struct {
+ S string
+}
+
+func (e *InvalidUTF8Error) String() string {
+ return "json: invalid UTF-8 in string: " + strconv.Quote(e.S)
+}
+
+type MarshalerError struct {
+ Type reflect.Type
+ Error os.Error
+}
+
+func (e *MarshalerError) String() string {
+ return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Error.String()
+}
+
+type interfaceOrPtrValue interface {
+ IsNil() bool
+ Elem() reflect.Value
+}
+
+var hex = "0123456789abcdef"
+
+// An encodeState encodes JSON into a bytes.Buffer.
+type encodeState struct {
+ bytes.Buffer // accumulated output
+}
+
+func (e *encodeState) marshal(v interface{}) (err os.Error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(runtime.Error); ok {
+ panic(r)
+ }
+ err = r.(os.Error)
+ }
+ }()
+ e.reflectValue(reflect.NewValue(v))
+ return nil
+}
+
+func (e *encodeState) error(err os.Error) {
+ panic(err)
+}
+
+var byteSliceType = reflect.Typeof([]byte(nil))
+
+func (e *encodeState) reflectValue(v reflect.Value) {
+ if !v.IsValid() {
+ e.WriteString("null")
+ return
+ }
+
+ if j, ok := v.Interface().(Marshaler); ok {
+ b, err := j.MarshalJSON()
+ if err == nil {
+ // copy JSON into buffer, checking validity.
+ err = Compact(&e.Buffer, b)
+ }
+ if err != nil {
+ e.error(&MarshalerError{v.Type(), err})
+ }
+ return
+ }
+
+ switch v.Kind() {
+ case reflect.Bool:
+ x := v.Bool()
+ if x {
+ e.WriteString("true")
+ } else {
+ e.WriteString("false")
+ }
+
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ e.WriteString(strconv.Itoa64(v.Int()))
+
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ e.WriteString(strconv.Uitoa64(v.Uint()))
+
+ case reflect.Float32, reflect.Float64:
+ e.WriteString(strconv.FtoaN(v.Float(), 'g', -1, v.Type().Bits()))
+
+ case reflect.String:
+ e.string(v.String())
+
+ case reflect.Struct:
+ e.WriteByte('{')
+ t := v.Type()
+ n := v.NumField()
+ first := true
+ for i := 0; i < n; i++ {
+ f := t.Field(i)
+ if f.PkgPath != "" {
+ continue
+ }
+ if first {
+ first = false
+ } else {
+ e.WriteByte(',')
+ }
+ if isValidTag(f.Tag) {
+ e.string(f.Tag)
+ } else {
+ e.string(f.Name)
+ }
+ e.WriteByte(':')
+ e.reflectValue(v.Field(i))
+ }
+ e.WriteByte('}')
+
+ case reflect.Map:
+ if v.Type().Key().Kind() != reflect.String {
+ e.error(&UnsupportedTypeError{v.Type()})
+ }
+ if v.IsNil() {
+ e.WriteString("null")
+ break
+ }
+ e.WriteByte('{')
+ var sv stringValues = v.MapKeys()
+ sort.Sort(sv)
+ for i, k := range sv {
+ if i > 0 {
+ e.WriteByte(',')
+ }
+ e.string(k.String())
+ e.WriteByte(':')
+ e.reflectValue(v.MapIndex(k))
+ }
+ e.WriteByte('}')
+
+ case reflect.Array, reflect.Slice:
+ if v.Type() == byteSliceType {
+ e.WriteByte('"')
+ s := v.Interface().([]byte)
+ if len(s) < 1024 {
+ // for small buffers, using Encode directly is much faster.
+ dst := make([]byte, base64.StdEncoding.EncodedLen(len(s)))
+ base64.StdEncoding.Encode(dst, s)
+ e.Write(dst)
+ } else {
+ // for large buffers, avoid unnecessary extra temporary
+ // buffer space.
+ enc := base64.NewEncoder(base64.StdEncoding, e)
+ enc.Write(s)
+ enc.Close()
+ }
+ e.WriteByte('"')
+ break
+ }
+ e.WriteByte('[')
+ n := v.Len()
+ for i := 0; i < n; i++ {
+ if i > 0 {
+ e.WriteByte(',')
+ }
+ e.reflectValue(v.Index(i))
+ }
+ e.WriteByte(']')
+
+ case interfaceOrPtrValue:
+ if v.IsNil() {
+ e.WriteString("null")
+ return
+ }
+ e.reflectValue(v.Elem())
+
+ default:
+ e.error(&UnsupportedTypeError{v.Type()})
+ }
+ return
+}
+
+func isValidTag(s string) bool {
+ if s == "" {
+ return false
+ }
+ for _, c := range s {
+ if c != '_' && !unicode.IsLetter(c) && !unicode.IsDigit(c) {
+ return false
+ }
+ }
+ return true
+}
+
+// stringValues is a slice of reflect.Value holding *reflect.StringValue.
+// It implements the methods to sort by string.
+type stringValues []reflect.Value
+
+func (sv stringValues) Len() int { return len(sv) }
+func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
+func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
+func (sv stringValues) get(i int) string { return sv[i].String() }
+
+func (e *encodeState) string(s string) {
+ e.WriteByte('"')
+ start := 0
+ for i := 0; i < len(s); {
+ if b := s[i]; b < utf8.RuneSelf {
+ if 0x20 <= b && b != '\\' && b != '"' {
+ i++
+ continue
+ }
+ if start < i {
+ e.WriteString(s[start:i])
+ }
+ if b == '\\' || b == '"' {
+ e.WriteByte('\\')
+ e.WriteByte(b)
+ } else {
+ e.WriteString(`\u00`)
+ e.WriteByte(hex[b>>4])
+ e.WriteByte(hex[b&0xF])
+ }
+ i++
+ start = i
+ continue
+ }
+ c, size := utf8.DecodeRuneInString(s[i:])
+ if c == utf8.RuneError && size == 1 {
+ e.error(&InvalidUTF8Error{s})
+ }
+ i += size
+ }
+ if start < len(s) {
+ e.WriteString(s[start:])
+ }
+ e.WriteByte('"')
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "sync"
+)
+
+// An Encoder manages the transmission of type and data information to the
+// other side of a connection.
+type Encoder struct {
+ mutex sync.Mutex // each item must be sent atomically
+ w []io.Writer // where to send the data
+ sent map[reflect.Type]typeId // which types we've already sent
+ countState *encoderState // stage for writing counts
+ freeList *encoderState // list of free encoderStates; avoids reallocation
+ buf []byte // for collecting the output.
+ byteBuf bytes.Buffer // buffer for top-level encoderState
+ err os.Error
+}
+
+// NewEncoder returns a new encoder that will transmit on the io.Writer.
+func NewEncoder(w io.Writer) *Encoder {
+ enc := new(Encoder)
+ enc.w = []io.Writer{w}
+ enc.sent = make(map[reflect.Type]typeId)
+ enc.countState = enc.newEncoderState(new(bytes.Buffer))
+ return enc
+}
+
+// writer() returns the innermost writer the encoder is using
+func (enc *Encoder) writer() io.Writer {
+ return enc.w[len(enc.w)-1]
+}
+
+// pushWriter adds a writer to the encoder.
+func (enc *Encoder) pushWriter(w io.Writer) {
+ enc.w = append(enc.w, w)
+}
+
+// popWriter pops the innermost writer.
+func (enc *Encoder) popWriter() {
+ enc.w = enc.w[0 : len(enc.w)-1]
+}
+
+func (enc *Encoder) badType(rt reflect.Type) {
+ enc.setError(os.ErrorString("gob: can't encode type " + rt.String()))
+}
+
+func (enc *Encoder) setError(err os.Error) {
+ if enc.err == nil { // remember the first.
+ enc.err = err
+ }
+}
+
+// writeMessage sends the data item preceded by a unsigned count of its length.
+func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
+ enc.countState.encodeUint(uint64(b.Len()))
+ // Build the buffer.
+ countLen := enc.countState.b.Len()
+ total := countLen + b.Len()
+ if total > len(enc.buf) {
+ enc.buf = make([]byte, total+1000) // extra for growth
+ }
+ // Place the length before the data.
+ // TODO(r): avoid the extra copy here.
+ enc.countState.b.Read(enc.buf[0:countLen])
+ // Now the data.
+ b.Read(enc.buf[countLen:total])
+ // Write the data.
+ _, err := w.Write(enc.buf[0:total])
+ if err != nil {
+ enc.setError(err)
+ }
+}
+
+// sendActualType sends the requested type, without further investigation, unless
+// it's been sent before.
+func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
+ if _, alreadySent := enc.sent[actual]; alreadySent {
+ return false
+ }
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ // Send the pair (-id, type)
+ // Id:
+ state.encodeInt(-int64(info.id))
+ // Type:
+ enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
+ enc.writeMessage(w, state.b)
+ if enc.err != nil {
+ return
+ }
+
+ // Remember we've sent this type, both what the user gave us and the base type.
+ enc.sent[ut.base] = info.id
+ if ut.user != ut.base {
+ enc.sent[ut.user] = info.id
+ }
+ // Now send the inner types
+ switch st := actual.(type) {
+ case *reflect.StructType:
+ for i := 0; i < st.NumField(); i++ {
+ enc.sendType(w, state, st.Field(i).Type)
+ }
+ case reflect.ArrayOrSliceType:
+ enc.sendType(w, state, st.Elem())
+ }
+ return true
+}
+
+// sendType sends the type info to the other side, if necessary.
+func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
+ ut := userType(origt)
+ if ut.isGobEncoder {
+ // The rules are different: regardless of the underlying type's representation,
+ // we need to tell the other side that this exact type is a GobEncoder.
+ return enc.sendActualType(w, state, ut, ut.user)
+ }
+
+ // It's a concrete value, so drill down to the base type.
+ switch rt := ut.base.(type) {
+ default:
+ // Basic types and interfaces do not need to be described.
+ return
+ case *reflect.SliceType:
+ // If it's []uint8, don't send; it's considered basic.
+ if rt.Elem().Kind() == reflect.Uint8 {
+ return
+ }
+ // Otherwise we do send.
+ break
+ case *reflect.ArrayType:
+ // arrays must be sent so we know their lengths and element types.
+ break
+ case *reflect.MapType:
+ // maps must be sent so we know their lengths and key/value types.
+ break
+ case *reflect.StructType:
+ // structs must be sent so we know their fields.
+ break
+ case *reflect.ChanType, *reflect.FuncType:
+ // Probably a bad field in a struct.
+ enc.badType(rt)
+ return
+ }
+
+ return enc.sendActualType(w, state, ut, ut.base)
+}
+
+// Encode transmits the data item represented by the empty interface value,
+// guaranteeing that all necessary type information has been transmitted first.
+func (enc *Encoder) Encode(e interface{}) os.Error {
+ return enc.EncodeValue(reflect.NewValue(e))
+}
+
+// sendTypeDescriptor makes sure the remote side knows about this type.
+// It will send a descriptor if this is the first time the type has been
+// sent.
+func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
+ // Make sure the type is known to the other side.
+ // First, have we already sent this type?
+ rt := ut.base
+ if ut.isGobEncoder {
+ rt = ut.user
+ }
+ if _, alreadySent := enc.sent[rt]; !alreadySent {
+ // No, so send it.
+ sent := enc.sendType(w, state, rt)
+ if enc.err != nil {
+ return
+ }
+ // If the type info has still not been transmitted, it means we have
+ // a singleton basic type (int, []byte etc.) at top level. We don't
+ // need to send the type info but we do need to update enc.sent.
+ if !sent {
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ enc.sent[rt] = info.id
+ }
+ }
+}
+
+// sendTypeId sends the id, which must have already been defined.
+func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) {
+ // Identify the type of this top-level value.
+ state.encodeInt(int64(enc.sent[ut.base]))
+}
+
+// EncodeValue transmits the data item represented by the reflection value,
+// guaranteeing that all necessary type information has been transmitted first.
+func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
+ // Make sure we're single-threaded through here, so multiple
+ // goroutines can share an encoder.
+ enc.mutex.Lock()
+ defer enc.mutex.Unlock()
+
+ // Remove any nested writers remaining due to previous errors.
+ enc.w = enc.w[0:1]
+
+ ut, err := validUserType(value.Type())
+ if err != nil {
+ return err
+ }
+
+ enc.err = nil
+ enc.byteBuf.Reset()
+ state := enc.newEncoderState(&enc.byteBuf)
+
+ enc.sendTypeDescriptor(enc.writer(), state, ut)
+ enc.sendTypeId(state, ut)
+ if enc.err != nil {
+ return enc.err
+ }
+
+ // Encode the object.
+ enc.encode(state.b, value, ut)
+ if enc.err == nil {
+ enc.writeMessage(enc.writer(), state.b)
+ }
+
+ enc.freeEncoderState(state)
+ return enc.err
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "sync"
+)
+
+// An Encoder manages the transmission of type and data information to the
+// other side of a connection.
+type Encoder struct {
+ mutex sync.Mutex // each item must be sent atomically
+ w []io.Writer // where to send the data
+ sent map[reflect.Type]typeId // which types we've already sent
+ countState *encoderState // stage for writing counts
+ freeList *encoderState // list of free encoderStates; avoids reallocation
+ buf []byte // for collecting the output.
+ byteBuf bytes.Buffer // buffer for top-level encoderState
+ err os.Error
+}
+
+// NewEncoder returns a new encoder that will transmit on the io.Writer.
+func NewEncoder(w io.Writer) *Encoder {
+ enc := new(Encoder)
+ enc.w = []io.Writer{w}
+ enc.sent = make(map[reflect.Type]typeId)
+ enc.countState = enc.newEncoderState(new(bytes.Buffer))
+ return enc
+}
+
+// writer() returns the innermost writer the encoder is using
+func (enc *Encoder) writer() io.Writer {
+ return enc.w[len(enc.w)-1]
+}
+
+// pushWriter adds a writer to the encoder.
+func (enc *Encoder) pushWriter(w io.Writer) {
+ enc.w = append(enc.w, w)
+}
+
+// popWriter pops the innermost writer.
+func (enc *Encoder) popWriter() {
+ enc.w = enc.w[0 : len(enc.w)-1]
+}
+
+func (enc *Encoder) badType(rt reflect.Type) {
+ enc.setError(os.ErrorString("gob: can't encode type " + rt.String()))
+}
+
+func (enc *Encoder) setError(err os.Error) {
+ if enc.err == nil { // remember the first.
+ enc.err = err
+ }
+}
+
+// writeMessage sends the data item preceded by a unsigned count of its length.
+func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
+ enc.countState.encodeUint(uint64(b.Len()))
+ // Build the buffer.
+ countLen := enc.countState.b.Len()
+ total := countLen + b.Len()
+ if total > len(enc.buf) {
+ enc.buf = make([]byte, total+1000) // extra for growth
+ }
+ // Place the length before the data.
+ // TODO(r): avoid the extra copy here.
+ enc.countState.b.Read(enc.buf[0:countLen])
+ // Now the data.
+ b.Read(enc.buf[countLen:total])
+ // Write the data.
+ _, err := w.Write(enc.buf[0:total])
+ if err != nil {
+ enc.setError(err)
+ }
+}
+
+// sendActualType sends the requested type, without further investigation, unless
+// it's been sent before.
+func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
+ if _, alreadySent := enc.sent[actual]; alreadySent {
+ return false
+ }
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ // Send the pair (-id, type)
+ // Id:
+ state.encodeInt(-int64(info.id))
+ // Type:
+ enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
+ enc.writeMessage(w, state.b)
+ if enc.err != nil {
+ return
+ }
+
+ // Remember we've sent this type, both what the user gave us and the base type.
+ enc.sent[ut.base] = info.id
+ if ut.user != ut.base {
+ enc.sent[ut.user] = info.id
+ }
+ // Now send the inner types
+ switch st := actual; st.Kind() {
+ case reflect.Struct:
+ for i := 0; i < st.NumField(); i++ {
+ enc.sendType(w, state, st.Field(i).Type)
+ }
+ case reflect.Array, reflect.Slice:
+ enc.sendType(w, state, st.Elem())
+ }
+ return true
+}
+
+// sendType sends the type info to the other side, if necessary.
+func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
+ ut := userType(origt)
+ if ut.isGobEncoder {
+ // The rules are different: regardless of the underlying type's representation,
+ // we need to tell the other side that this exact type is a GobEncoder.
+ return enc.sendActualType(w, state, ut, ut.user)
+ }
+
+ // It's a concrete value, so drill down to the base type.
+ switch rt := ut.base; rt.Kind() {
+ default:
+ // Basic types and interfaces do not need to be described.
+ return
+ case reflect.Slice:
+ // If it's []uint8, don't send; it's considered basic.
+ if rt.Elem().Kind() == reflect.Uint8 {
+ return
+ }
+ // Otherwise we do send.
+ break
+ case reflect.Array:
+ // arrays must be sent so we know their lengths and element types.
+ break
+ case reflect.Map:
+ // maps must be sent so we know their lengths and key/value types.
+ break
+ case reflect.Struct:
+ // structs must be sent so we know their fields.
+ break
+ case reflect.Chan, reflect.Func:
+ // Probably a bad field in a struct.
+ enc.badType(rt)
+ return
+ }
+
+ return enc.sendActualType(w, state, ut, ut.base)
+}
+
+// Encode transmits the data item represented by the empty interface value,
+// guaranteeing that all necessary type information has been transmitted first.
+func (enc *Encoder) Encode(e interface{}) os.Error {
+ return enc.EncodeValue(reflect.NewValue(e))
+}
+
+// sendTypeDescriptor makes sure the remote side knows about this type.
+// It will send a descriptor if this is the first time the type has been
+// sent.
+func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
+ // Make sure the type is known to the other side.
+ // First, have we already sent this type?
+ rt := ut.base
+ if ut.isGobEncoder {
+ rt = ut.user
+ }
+ if _, alreadySent := enc.sent[rt]; !alreadySent {
+ // No, so send it.
+ sent := enc.sendType(w, state, rt)
+ if enc.err != nil {
+ return
+ }
+ // If the type info has still not been transmitted, it means we have
+ // a singleton basic type (int, []byte etc.) at top level. We don't
+ // need to send the type info but we do need to update enc.sent.
+ if !sent {
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ enc.sent[rt] = info.id
+ }
+ }
+}
+
+// sendTypeId sends the id, which must have already been defined.
+func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) {
+ // Identify the type of this top-level value.
+ state.encodeInt(int64(enc.sent[ut.base]))
+}
+
+// EncodeValue transmits the data item represented by the reflection value,
+// guaranteeing that all necessary type information has been transmitted first.
+func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
+ // Make sure we're single-threaded through here, so multiple
+ // goroutines can share an encoder.
+ enc.mutex.Lock()
+ defer enc.mutex.Unlock()
+
+ // Remove any nested writers remaining due to previous errors.
+ enc.w = enc.w[0:1]
+
+ ut, err := validUserType(value.Type())
+ if err != nil {
+ return err
+ }
+
+ enc.err = nil
+ enc.byteBuf.Reset()
+ state := enc.newEncoderState(&enc.byteBuf)
+
+ enc.sendTypeDescriptor(enc.writer(), state, ut)
+ enc.sendTypeId(state, ut)
+ if enc.err != nil {
+ return enc.err
+ }
+
+ // Encode the object.
+ enc.encode(state.b, value, ut)
+ if enc.err == nil {
+ enc.writeMessage(enc.writer(), state.b)
+ }
+
+ enc.freeEncoderState(state)
+ return enc.err
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ The netchan package implements type-safe networked channels:
+ it allows the two ends of a channel to appear on different
+ computers connected by a network. It does this by transporting
+ data sent to a channel on one machine so it can be recovered
+ by a receive of a channel of the same type on the other.
+
+ An exporter publishes a set of channels by name. An importer
+ connects to the exporting machine and imports the channels
+ by name. After importing the channels, the two machines can
+ use the channels in the usual way.
+
+ Networked channels are not synchronized; they always behave
+ as if they are buffered channels of at least one element.
+*/
+package netchan
+
+// BUG: can't use range clause to receive when using ImportNValues to limit the count.
+
+import (
+ "log"
+ "io"
+ "net"
+ "os"
+ "reflect"
+ "strconv"
+ "sync"
+)
+
+// Export
+
+// expLog is a logging convenience function. The first argument must be a string.
+func expLog(args ...interface{}) {
+ args[0] = "netchan export: " + args[0].(string)
+ log.Print(args...)
+}
+
+// An Exporter allows a set of channels to be published on a single
+// network port. A single machine may have multiple Exporters
+// but they must use different ports.
+type Exporter struct {
+ *clientSet
+}
+
+type expClient struct {
+ *encDec
+ exp *Exporter
+ chans map[int]*netChan // channels in use by client
+ mu sync.Mutex // protects remaining fields
+ errored bool // client has been sent an error
+ seqNum int64 // sequences messages sent to client; has value of highest sent
+ ackNum int64 // highest sequence number acknowledged
+ seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
+}
+
+func newClient(exp *Exporter, conn io.ReadWriter) *expClient {
+ client := new(expClient)
+ client.exp = exp
+ client.encDec = newEncDec(conn)
+ client.seqNum = 0
+ client.ackNum = 0
+ client.chans = make(map[int]*netChan)
+ return client
+}
+
+func (client *expClient) sendError(hdr *header, err string) {
+ error := &error{err}
+ expLog("sending error to client:", error.Error)
+ client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
+ client.mu.Lock()
+ client.errored = true
+ client.mu.Unlock()
+}
+
+func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan {
+ exp := client.exp
+ exp.mu.Lock()
+ ech, ok := exp.names[name]
+ exp.mu.Unlock()
+ if !ok {
+ client.sendError(hdr, "no such channel: "+name)
+ return nil
+ }
+ if ech.dir != dir {
+ client.sendError(hdr, "wrong direction for channel: "+name)
+ return nil
+ }
+ nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count)
+ client.chans[hdr.Id] = nch
+ return nch
+}
+
+func (client *expClient) getChan(hdr *header, dir Dir) *netChan {
+ nch := client.chans[hdr.Id]
+ if nch == nil {
+ return nil
+ }
+ if nch.dir != dir {
+ client.sendError(hdr, "wrong direction for channel: "+nch.name)
+ }
+ return nch
+}
+
+// The function run manages sends and receives for a single client. For each
+// (client Recv) request, this will launch a serveRecv goroutine to deliver
+// the data for that channel, while (client Send) requests are handled as
+// data arrives from the client.
+func (client *expClient) run() {
+ hdr := new(header)
+ hdrValue := reflect.NewValue(hdr)
+ req := new(request)
+ reqValue := reflect.NewValue(req)
+ error := new(error)
+ for {
+ *hdr = header{}
+ if err := client.decode(hdrValue); err != nil {
+ if err != os.EOF {
+ expLog("error decoding client header:", err)
+ }
+ break
+ }
+ switch hdr.PayloadType {
+ case payRequest:
+ *req = request{}
+ if err := client.decode(reqValue); err != nil {
+ expLog("error decoding client request:", err)
+ break
+ }
+ if req.Size < 1 {
+ panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values")
+ }
+ switch req.Dir {
+ case Recv:
+ // look up channel before calling serveRecv to
+ // avoid a lock around client.chans.
+ if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil {
+ go client.serveRecv(nch, *hdr, req.Count)
+ }
+ case Send:
+ client.newChan(hdr, Recv, req.Name, req.Size, req.Count)
+ // The actual sends will have payload type payData.
+ // TODO: manage the count?
+ default:
+ error.Error = "request: can't handle channel direction"
+ expLog(error.Error, req.Dir)
+ client.encode(hdr, payError, error)
+ }
+ case payData:
+ client.serveSend(*hdr)
+ case payClosed:
+ client.serveClosed(*hdr)
+ case payAck:
+ client.mu.Lock()
+ if client.ackNum != hdr.SeqNum-1 {
+ // Since the sequence number is incremented and the message is sent
+ // in a single instance of locking client.mu, the messages are guaranteed
+ // to be sent in order. Therefore receipt of acknowledgement N means
+ // all messages <=N have been seen by the recipient. We check anyway.
+ expLog("sequence out of order:", client.ackNum, hdr.SeqNum)
+ }
+ if client.ackNum < hdr.SeqNum { // If there has been an error, don't back up the count.
+ client.ackNum = hdr.SeqNum
+ }
+ client.mu.Unlock()
+ case payAckSend:
+ if nch := client.getChan(hdr, Send); nch != nil {
+ nch.acked()
+ }
+ default:
+ log.Fatal("netchan export: unknown payload type", hdr.PayloadType)
+ }
+ }
+ client.exp.delClient(client)
+}
+
+// Send all the data on a single channel to a client asking for a Recv.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) {
+ for {
+ val, ok := nch.recv()
+ if !ok {
+ if err := client.encode(&hdr, payClosed, nil); err != nil {
+ expLog("error encoding server closed message:", err)
+ }
+ break
+ }
+ // We hold the lock during transmission to guarantee messages are
+ // sent in sequence number order. Also, we increment first so the
+ // value of client.SeqNum is the value of the highest used sequence
+ // number, not one beyond.
+ client.mu.Lock()
+ client.seqNum++
+ hdr.SeqNum = client.seqNum
+ client.seqLock.Lock() // guarantee ordering of messages
+ client.mu.Unlock()
+ err := client.encode(&hdr, payData, val.Interface())
+ client.seqLock.Unlock()
+ if err != nil {
+ expLog("error encoding client response:", err)
+ client.sendError(&hdr, err.String())
+ break
+ }
+ // Negative count means run forever.
+ if count >= 0 {
+ if count--; count <= 0 {
+ break
+ }
+ }
+ }
+}
+
+// Receive and deliver locally one item from a client asking for a Send
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveSend(hdr header) {
+ nch := client.getChan(&hdr, Recv)
+ if nch == nil {
+ return
+ }
+ // Create a new value for each received item.
+ val := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem())
+ if err := client.decode(val); err != nil {
+ expLog("value decode:", err, "; type ", nch.ch.Type())
+ return
+ }
+ nch.send(val)
+}
+
+// Report that client has closed the channel that is sending to us.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveClosed(hdr header) {
+ nch := client.getChan(&hdr, Recv)
+ if nch == nil {
+ return
+ }
+ nch.close()
+}
+
+func (client *expClient) unackedCount() int64 {
+ client.mu.Lock()
+ n := client.seqNum - client.ackNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) seq() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) ack() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+// Serve waits for incoming connections on the listener
+// and serves the Exporter's channels on each.
+// It blocks until the listener is closed.
+func (exp *Exporter) Serve(listener net.Listener) {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ expLog("listen:", err)
+ break
+ }
+ go exp.ServeConn(conn)
+ }
+}
+
+// ServeConn exports the Exporter's channels on conn.
+// It blocks until the connection is terminated.
+func (exp *Exporter) ServeConn(conn io.ReadWriter) {
+ exp.addClient(conn).run()
+}
+
+// NewExporter creates a new Exporter that exports a set of channels.
+func NewExporter() *Exporter {
+ e := &Exporter{
+ clientSet: &clientSet{
+ names: make(map[string]*chanDir),
+ clients: make(map[unackedCounter]bool),
+ },
+ }
+ return e
+}
+
+// ListenAndServe exports the exporter's channels through the
+// given network and local address defined as in net.Listen.
+func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error {
+ listener, err := net.Listen(network, localaddr)
+ if err != nil {
+ return err
+ }
+ go exp.Serve(listener)
+ return nil
+}
+
+// addClient creates a new expClient and records its existence
+func (exp *Exporter) addClient(conn io.ReadWriter) *expClient {
+ client := newClient(exp, conn)
+ exp.mu.Lock()
+ exp.clients[client] = true
+ exp.mu.Unlock()
+ return client
+}
+
+// delClient forgets the client existed
+func (exp *Exporter) delClient(client *expClient) {
+ exp.mu.Lock()
+ exp.clients[client] = false, false
+ exp.mu.Unlock()
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any client and possibly including those sent while
+// Drain was executing, have been received by the importer. In short, it
+// waits until all the exporter's messages have been received by a client.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (exp *Exporter) Drain(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.drain(timeout)
+}
+
+// Sync waits until all clients of the exporter have received the messages
+// that were sent at the time Sync was invoked. Unlike Drain, it does not
+// wait for messages sent while it is running or messages that have not been
+// dispatched to any client. If the timeout (measured in nanoseconds) is
+// positive and Sync takes longer than that to complete, an error is
+// returned.
+func (exp *Exporter) Sync(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.sync(timeout)
+}
+
+func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) {
+ chanType, ok := reflect.Typeof(chT).(*reflect.ChanType)
+ if !ok {
+ return nil, os.ErrorString("not a channel")
+ }
+ if dir != Send && dir != Recv {
+ return nil, os.ErrorString("unknown channel direction")
+ }
+ switch chanType.Dir() {
+ case reflect.BothDir:
+ case reflect.SendDir:
+ if dir != Recv {
+ return nil, os.ErrorString("to import/export with Send, must provide <-chan")
+ }
+ case reflect.RecvDir:
+ if dir != Send {
+ return nil, os.ErrorString("to import/export with Recv, must provide chan<-")
+ }
+ }
+ return reflect.NewValue(chT).(*reflect.ChanValue), nil
+}
+
+// Export exports a channel of a given type and specified direction. The
+// channel to be exported is provided in the call and may be of arbitrary
+// channel type.
+// Despite the literal signature, the effective signature is
+// Export(name string, chT chan T, dir Dir)
+func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
+ ch, err := checkChan(chT, dir)
+ if err != nil {
+ return err
+ }
+ exp.mu.Lock()
+ defer exp.mu.Unlock()
+ _, present := exp.names[name]
+ if present {
+ return os.ErrorString("channel name already being exported:" + name)
+ }
+ exp.names[name] = &chanDir{ch, dir}
+ return nil
+}
+
+// Hangup disassociates the named channel from the Exporter and closes
+// the channel. Messages in flight for the channel may be dropped.
+func (exp *Exporter) Hangup(name string) os.Error {
+ exp.mu.Lock()
+ chDir, ok := exp.names[name]
+ if ok {
+ exp.names[name] = nil, false
+ }
+ // TODO drop all instances of channel from client sets
+ exp.mu.Unlock()
+ if !ok {
+ return os.ErrorString("netchan export: hangup: no such channel: " + name)
+ }
+ chDir.ch.Close()
+ return nil
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ The netchan package implements type-safe networked channels:
+ it allows the two ends of a channel to appear on different
+ computers connected by a network. It does this by transporting
+ data sent to a channel on one machine so it can be recovered
+ by a receive of a channel of the same type on the other.
+
+ An exporter publishes a set of channels by name. An importer
+ connects to the exporting machine and imports the channels
+ by name. After importing the channels, the two machines can
+ use the channels in the usual way.
+
+ Networked channels are not synchronized; they always behave
+ as if they are buffered channels of at least one element.
+*/
+package netchan
+
+// BUG: can't use range clause to receive when using ImportNValues to limit the count.
+
+import (
+ "log"
+ "io"
+ "net"
+ "os"
+ "reflect"
+ "strconv"
+ "sync"
+)
+
+// Export
+
+// expLog is a logging convenience function. The first argument must be a string.
+func expLog(args ...interface{}) {
+ args[0] = "netchan export: " + args[0].(string)
+ log.Print(args...)
+}
+
+// An Exporter allows a set of channels to be published on a single
+// network port. A single machine may have multiple Exporters
+// but they must use different ports.
+type Exporter struct {
+ *clientSet
+}
+
+type expClient struct {
+ *encDec
+ exp *Exporter
+ chans map[int]*netChan // channels in use by client
+ mu sync.Mutex // protects remaining fields
+ errored bool // client has been sent an error
+ seqNum int64 // sequences messages sent to client; has value of highest sent
+ ackNum int64 // highest sequence number acknowledged
+ seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
+}
+
+func newClient(exp *Exporter, conn io.ReadWriter) *expClient {
+ client := new(expClient)
+ client.exp = exp
+ client.encDec = newEncDec(conn)
+ client.seqNum = 0
+ client.ackNum = 0
+ client.chans = make(map[int]*netChan)
+ return client
+}
+
+func (client *expClient) sendError(hdr *header, err string) {
+ error := &error{err}
+ expLog("sending error to client:", error.Error)
+ client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
+ client.mu.Lock()
+ client.errored = true
+ client.mu.Unlock()
+}
+
+func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan {
+ exp := client.exp
+ exp.mu.Lock()
+ ech, ok := exp.names[name]
+ exp.mu.Unlock()
+ if !ok {
+ client.sendError(hdr, "no such channel: "+name)
+ return nil
+ }
+ if ech.dir != dir {
+ client.sendError(hdr, "wrong direction for channel: "+name)
+ return nil
+ }
+ nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count)
+ client.chans[hdr.Id] = nch
+ return nch
+}
+
+func (client *expClient) getChan(hdr *header, dir Dir) *netChan {
+ nch := client.chans[hdr.Id]
+ if nch == nil {
+ return nil
+ }
+ if nch.dir != dir {
+ client.sendError(hdr, "wrong direction for channel: "+nch.name)
+ }
+ return nch
+}
+
+// The function run manages sends and receives for a single client. For each
+// (client Recv) request, this will launch a serveRecv goroutine to deliver
+// the data for that channel, while (client Send) requests are handled as
+// data arrives from the client.
+func (client *expClient) run() {
+ hdr := new(header)
+ hdrValue := reflect.NewValue(hdr)
+ req := new(request)
+ reqValue := reflect.NewValue(req)
+ error := new(error)
+ for {
+ *hdr = header{}
+ if err := client.decode(hdrValue); err != nil {
+ if err != os.EOF {
+ expLog("error decoding client header:", err)
+ }
+ break
+ }
+ switch hdr.PayloadType {
+ case payRequest:
+ *req = request{}
+ if err := client.decode(reqValue); err != nil {
+ expLog("error decoding client request:", err)
+ break
+ }
+ if req.Size < 1 {
+ panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values")
+ }
+ switch req.Dir {
+ case Recv:
+ // look up channel before calling serveRecv to
+ // avoid a lock around client.chans.
+ if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil {
+ go client.serveRecv(nch, *hdr, req.Count)
+ }
+ case Send:
+ client.newChan(hdr, Recv, req.Name, req.Size, req.Count)
+ // The actual sends will have payload type payData.
+ // TODO: manage the count?
+ default:
+ error.Error = "request: can't handle channel direction"
+ expLog(error.Error, req.Dir)
+ client.encode(hdr, payError, error)
+ }
+ case payData:
+ client.serveSend(*hdr)
+ case payClosed:
+ client.serveClosed(*hdr)
+ case payAck:
+ client.mu.Lock()
+ if client.ackNum != hdr.SeqNum-1 {
+ // Since the sequence number is incremented and the message is sent
+ // in a single instance of locking client.mu, the messages are guaranteed
+ // to be sent in order. Therefore receipt of acknowledgement N means
+ // all messages <=N have been seen by the recipient. We check anyway.
+ expLog("sequence out of order:", client.ackNum, hdr.SeqNum)
+ }
+ if client.ackNum < hdr.SeqNum { // If there has been an error, don't back up the count.
+ client.ackNum = hdr.SeqNum
+ }
+ client.mu.Unlock()
+ case payAckSend:
+ if nch := client.getChan(hdr, Send); nch != nil {
+ nch.acked()
+ }
+ default:
+ log.Fatal("netchan export: unknown payload type", hdr.PayloadType)
+ }
+ }
+ client.exp.delClient(client)
+}
+
+// Send all the data on a single channel to a client asking for a Recv.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) {
+ for {
+ val, ok := nch.recv()
+ if !ok {
+ if err := client.encode(&hdr, payClosed, nil); err != nil {
+ expLog("error encoding server closed message:", err)
+ }
+ break
+ }
+ // We hold the lock during transmission to guarantee messages are
+ // sent in sequence number order. Also, we increment first so the
+ // value of client.SeqNum is the value of the highest used sequence
+ // number, not one beyond.
+ client.mu.Lock()
+ client.seqNum++
+ hdr.SeqNum = client.seqNum
+ client.seqLock.Lock() // guarantee ordering of messages
+ client.mu.Unlock()
+ err := client.encode(&hdr, payData, val.Interface())
+ client.seqLock.Unlock()
+ if err != nil {
+ expLog("error encoding client response:", err)
+ client.sendError(&hdr, err.String())
+ break
+ }
+ // Negative count means run forever.
+ if count >= 0 {
+ if count--; count <= 0 {
+ break
+ }
+ }
+ }
+}
+
+// Receive and deliver locally one item from a client asking for a Send
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveSend(hdr header) {
+ nch := client.getChan(&hdr, Recv)
+ if nch == nil {
+ return
+ }
+ // Create a new value for each received item.
+ val := reflect.Zero(nch.ch.Type().Elem())
+ if err := client.decode(val); err != nil {
+ expLog("value decode:", err, "; type ", nch.ch.Type())
+ return
+ }
+ nch.send(val)
+}
+
+// Report that client has closed the channel that is sending to us.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveClosed(hdr header) {
+ nch := client.getChan(&hdr, Recv)
+ if nch == nil {
+ return
+ }
+ nch.close()
+}
+
+func (client *expClient) unackedCount() int64 {
+ client.mu.Lock()
+ n := client.seqNum - client.ackNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) seq() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) ack() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+// Serve waits for incoming connections on the listener
+// and serves the Exporter's channels on each.
+// It blocks until the listener is closed.
+func (exp *Exporter) Serve(listener net.Listener) {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ expLog("listen:", err)
+ break
+ }
+ go exp.ServeConn(conn)
+ }
+}
+
+// ServeConn exports the Exporter's channels on conn.
+// It blocks until the connection is terminated.
+func (exp *Exporter) ServeConn(conn io.ReadWriter) {
+ exp.addClient(conn).run()
+}
+
+// NewExporter creates a new Exporter that exports a set of channels.
+func NewExporter() *Exporter {
+ e := &Exporter{
+ clientSet: &clientSet{
+ names: make(map[string]*chanDir),
+ clients: make(map[unackedCounter]bool),
+ },
+ }
+ return e
+}
+
+// ListenAndServe exports the exporter's channels through the
+// given network and local address defined as in net.Listen.
+func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error {
+ listener, err := net.Listen(network, localaddr)
+ if err != nil {
+ return err
+ }
+ go exp.Serve(listener)
+ return nil
+}
+
+// addClient creates a new expClient and records its existence
+func (exp *Exporter) addClient(conn io.ReadWriter) *expClient {
+ client := newClient(exp, conn)
+ exp.mu.Lock()
+ exp.clients[client] = true
+ exp.mu.Unlock()
+ return client
+}
+
+// delClient forgets the client existed
+func (exp *Exporter) delClient(client *expClient) {
+ exp.mu.Lock()
+ exp.clients[client] = false, false
+ exp.mu.Unlock()
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any client and possibly including those sent while
+// Drain was executing, have been received by the importer. In short, it
+// waits until all the exporter's messages have been received by a client.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (exp *Exporter) Drain(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.drain(timeout)
+}
+
+// Sync waits until all clients of the exporter have received the messages
+// that were sent at the time Sync was invoked. Unlike Drain, it does not
+// wait for messages sent while it is running or messages that have not been
+// dispatched to any client. If the timeout (measured in nanoseconds) is
+// positive and Sync takes longer than that to complete, an error is
+// returned.
+func (exp *Exporter) Sync(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.sync(timeout)
+}
+
+func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) {
+ chanType := reflect.Typeof(chT)
+ if chanType.Kind() != reflect.Chan {
+ return reflect.Value{}, os.ErrorString("not a channel")
+ }
+ if dir != Send && dir != Recv {
+ return reflect.Value{}, os.ErrorString("unknown channel direction")
+ }
+ switch chanType.ChanDir() {
+ case reflect.BothDir:
+ case reflect.SendDir:
+ if dir != Recv {
+ return reflect.Value{}, os.ErrorString("to import/export with Send, must provide <-chan")
+ }
+ case reflect.RecvDir:
+ if dir != Send {
+ return reflect.Value{}, os.ErrorString("to import/export with Recv, must provide chan<-")
+ }
+ }
+ return reflect.NewValue(chT), nil
+}
+
+// Export exports a channel of a given type and specified direction. The
+// channel to be exported is provided in the call and may be of arbitrary
+// channel type.
+// Despite the literal signature, the effective signature is
+// Export(name string, chT chan T, dir Dir)
+func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
+ ch, err := checkChan(chT, dir)
+ if err != nil {
+ return err
+ }
+ exp.mu.Lock()
+ defer exp.mu.Unlock()
+ _, present := exp.names[name]
+ if present {
+ return os.ErrorString("channel name already being exported:" + name)
+ }
+ exp.names[name] = &chanDir{ch, dir}
+ return nil
+}
+
+// Hangup disassociates the named channel from the Exporter and closes
+// the channel. Messages in flight for the channel may be dropped.
+func (exp *Exporter) Hangup(name string) os.Error {
+ exp.mu.Lock()
+ chDir, ok := exp.names[name]
+ if ok {
+ exp.names[name] = nil, false
+ }
+ // TODO drop all instances of channel from client sets
+ exp.mu.Unlock()
+ if !ok {
+ return os.ErrorString("netchan export: hangup: no such channel: " + name)
+ }
+ chDir.ch.Close()
+ return nil
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "utf8"
+)
+
+// Some constants in the form of bytes, to avoid string overhead.
+// Needlessly fastidious, I suppose.
+var (
+ commaSpaceBytes = []byte(", ")
+ nilAngleBytes = []byte("<nil>")
+ nilParenBytes = []byte("(nil)")
+ nilBytes = []byte("nil")
+ mapBytes = []byte("map[")
+ missingBytes = []byte("(MISSING)")
+ extraBytes = []byte("%!(EXTRA ")
+ irparenBytes = []byte("i)")
+ bytesBytes = []byte("[]byte{")
+ widthBytes = []byte("%!(BADWIDTH)")
+ precBytes = []byte("%!(BADPREC)")
+ noVerbBytes = []byte("%!(NOVERB)")
+)
+
+// State represents the printer state passed to custom formatters.
+// It provides access to the io.Writer interface plus information about
+// the flags and options for the operand's format specifier.
+type State interface {
+ // Write is the function to call to emit formatted output to be printed.
+ Write(b []byte) (ret int, err os.Error)
+ // Width returns the value of the width option and whether it has been set.
+ Width() (wid int, ok bool)
+ // Precision returns the value of the precision option and whether it has been set.
+ Precision() (prec int, ok bool)
+
+ // Flag returns whether the flag c, a character, has been set.
+ Flag(int) bool
+}
+
+// Formatter is the interface implemented by values with a custom formatter.
+// The implementation of Format may call Sprintf or Fprintf(f) etc.
+// to generate its output.
+type Formatter interface {
+ Format(f State, c int)
+}
+
+// Stringer is implemented by any value that has a String method(),
+// which defines the ``native'' format for that value.
+// The String method is used to print values passed as an operand
+// to a %s or %v format or to an unformatted printer such as Print.
+type Stringer interface {
+ String() string
+}
+
+// GoStringer is implemented by any value that has a GoString() method,
+// which defines the Go syntax for that value.
+// The GoString method is used to print values passed as an operand
+// to a %#v format.
+type GoStringer interface {
+ GoString() string
+}
+
+type pp struct {
+ n int
+ buf bytes.Buffer
+ runeBuf [utf8.UTFMax]byte
+ fmt fmt
+}
+
+// A cache holds a set of reusable objects.
+// The buffered channel holds the currently available objects.
+// If more are needed, the cache creates them by calling new.
+type cache struct {
+ saved chan interface{}
+ new func() interface{}
+}
+
+func (c *cache) put(x interface{}) {
+ select {
+ case c.saved <- x:
+ // saved in cache
+ default:
+ // discard
+ }
+}
+
+func (c *cache) get() interface{} {
+ select {
+ case x := <-c.saved:
+ return x // reused from cache
+ default:
+ return c.new()
+ }
+ panic("not reached")
+}
+
+func newCache(f func() interface{}) *cache {
+ return &cache{make(chan interface{}, 100), f}
+}
+
+var ppFree = newCache(func() interface{} { return new(pp) })
+
+// Allocate a new pp struct or grab a cached one.
+func newPrinter() *pp {
+ p := ppFree.get().(*pp)
+ p.fmt.init(&p.buf)
+ return p
+}
+
+// Save used pp structs in ppFree; avoids an allocation per invocation.
+func (p *pp) free() {
+ // Don't hold on to pp structs with large buffers.
+ if cap(p.buf.Bytes()) > 1024 {
+ return
+ }
+ p.buf.Reset()
+ ppFree.put(p)
+}
+
+func (p *pp) Width() (wid int, ok bool) { return p.fmt.wid, p.fmt.widPresent }
+
+func (p *pp) Precision() (prec int, ok bool) { return p.fmt.prec, p.fmt.precPresent }
+
+func (p *pp) Flag(b int) bool {
+ switch b {
+ case '-':
+ return p.fmt.minus
+ case '+':
+ return p.fmt.plus
+ case '#':
+ return p.fmt.sharp
+ case ' ':
+ return p.fmt.space
+ case '0':
+ return p.fmt.zero
+ }
+ return false
+}
+
+func (p *pp) add(c int) {
+ p.buf.WriteRune(c)
+}
+
+// Implement Write so we can call Fprintf on a pp (through State), for
+// recursive use in custom verbs.
+func (p *pp) Write(b []byte) (ret int, err os.Error) {
+ return p.buf.Write(b)
+}
+
+// These routines end in 'f' and take a format string.
+
+// Fprintf formats according to a format specifier and writes to w.
+// It returns the number of bytes written and any write error encountered.
+func Fprintf(w io.Writer, format string, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrintf(format, a)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Printf formats according to a format specifier and writes to standard output.
+// It returns the number of bytes written and any write error encountered.
+func Printf(format string, a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprintf(os.Stdout, format, a...)
+ return n, errno
+}
+
+// Sprintf formats according to a format specifier and returns the resulting string.
+func Sprintf(format string, a ...interface{}) string {
+ p := newPrinter()
+ p.doPrintf(format, a)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// Errorf formats according to a format specifier and returns the string
+// converted to an os.ErrorString, which satisfies the os.Error interface.
+func Errorf(format string, a ...interface{}) os.Error {
+ return os.ErrorString(Sprintf(format, a...))
+}
+
+// These routines do not take a format string
+
+// Fprint formats using the default formats for its operands and writes to w.
+// Spaces are added between operands when neither is a string.
+// It returns the number of bytes written and any write error encountered.
+func Fprint(w io.Writer, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrint(a, false, false)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Print formats using the default formats for its operands and writes to standard output.
+// Spaces are added between operands when neither is a string.
+// It returns the number of bytes written and any write error encountered.
+func Print(a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprint(os.Stdout, a...)
+ return n, errno
+}
+
+// Sprint formats using the default formats for its operands and returns the resulting string.
+// Spaces are added between operands when neither is a string.
+func Sprint(a ...interface{}) string {
+ p := newPrinter()
+ p.doPrint(a, false, false)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// These routines end in 'ln', do not take a format string,
+// always add spaces between operands, and add a newline
+// after the last operand.
+
+// Fprintln formats using the default formats for its operands and writes to w.
+// Spaces are always added between operands and a newline is appended.
+// It returns the number of bytes written and any write error encountered.
+func Fprintln(w io.Writer, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrint(a, true, true)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Println formats using the default formats for its operands and writes to standard output.
+// Spaces are always added between operands and a newline is appended.
+// It returns the number of bytes written and any write error encountered.
+func Println(a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprintln(os.Stdout, a...)
+ return n, errno
+}
+
+// Sprintln formats using the default formats for its operands and returns the resulting string.
+// Spaces are always added between operands and a newline is appended.
+func Sprintln(a ...interface{}) string {
+ p := newPrinter()
+ p.doPrint(a, true, true)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+
+// Get the i'th arg of the struct value.
+// If the arg itself is an interface, return a value for
+// the thing inside the interface, not the interface itself.
+func getField(v *reflect.StructValue, i int) reflect.Value {
+ val := v.Field(i)
+ if i, ok := val.(*reflect.InterfaceValue); ok {
+ if inter := i.Interface(); inter != nil {
+ return reflect.NewValue(inter)
+ }
+ }
+ return val
+}
+
+// Convert ASCII to integer. n is 0 (and got is false) if no number present.
+func parsenum(s string, start, end int) (num int, isnum bool, newi int) {
+ if start >= end {
+ return 0, false, end
+ }
+ for newi = start; newi < end && '0' <= s[newi] && s[newi] <= '9'; newi++ {
+ num = num*10 + int(s[newi]-'0')
+ isnum = true
+ }
+ return
+}
+
+func (p *pp) unknownType(v interface{}) {
+ if v == nil {
+ p.buf.Write(nilAngleBytes)
+ return
+ }
+ p.buf.WriteByte('?')
+ p.buf.WriteString(reflect.Typeof(v).String())
+ p.buf.WriteByte('?')
+}
+
+func (p *pp) badVerb(verb int, val interface{}) {
+ p.add('%')
+ p.add('!')
+ p.add(verb)
+ p.add('(')
+ if val == nil {
+ p.buf.Write(nilAngleBytes)
+ } else {
+ p.buf.WriteString(reflect.Typeof(val).String())
+ p.add('=')
+ p.printField(val, 'v', false, false, 0)
+ }
+ p.add(')')
+}
+
+func (p *pp) fmtBool(v bool, verb int, value interface{}) {
+ switch verb {
+ case 't', 'v':
+ p.fmt.fmt_boolean(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+// fmtC formats a rune for the 'c' format.
+func (p *pp) fmtC(c int64) {
+ rune := int(c) // Check for overflow.
+ if int64(rune) != c {
+ rune = utf8.RuneError
+ }
+ w := utf8.EncodeRune(p.runeBuf[0:utf8.UTFMax], rune)
+ p.fmt.pad(p.runeBuf[0:w])
+}
+
+func (p *pp) fmtInt64(v int64, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.integer(v, 2, signed, ldigits)
+ case 'c':
+ p.fmtC(v)
+ case 'd', 'v':
+ p.fmt.integer(v, 10, signed, ldigits)
+ case 'o':
+ p.fmt.integer(v, 8, signed, ldigits)
+ case 'x':
+ p.fmt.integer(v, 16, signed, ldigits)
+ case 'U':
+ p.fmtUnicode(v)
+ case 'X':
+ p.fmt.integer(v, 16, signed, udigits)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+// fmt0x64 formats a uint64 in hexadecimal and prefixes it with 0x or
+// not, as requested, by temporarily setting the sharp flag.
+func (p *pp) fmt0x64(v uint64, leading0x bool) {
+ sharp := p.fmt.sharp
+ p.fmt.sharp = leading0x
+ p.fmt.integer(int64(v), 16, unsigned, ldigits)
+ p.fmt.sharp = sharp
+}
+
+// fmtUnicode formats a uint64 in U+1234 form by
+// temporarily turning on the unicode flag and tweaking the precision.
+func (p *pp) fmtUnicode(v int64) {
+ precPresent := p.fmt.precPresent
+ prec := p.fmt.prec
+ if !precPresent {
+ // If prec is already set, leave it alone; otherwise 4 is minimum.
+ p.fmt.prec = 4
+ p.fmt.precPresent = true
+ }
+ p.fmt.unicode = true // turn on U+
+ p.fmt.integer(int64(v), 16, unsigned, udigits)
+ p.fmt.unicode = false
+ p.fmt.prec = prec
+ p.fmt.precPresent = precPresent
+}
+
+func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.integer(int64(v), 2, unsigned, ldigits)
+ case 'c':
+ p.fmtC(int64(v))
+ case 'd':
+ p.fmt.integer(int64(v), 10, unsigned, ldigits)
+ case 'v':
+ if goSyntax {
+ p.fmt0x64(v, true)
+ } else {
+ p.fmt.integer(int64(v), 10, unsigned, ldigits)
+ }
+ case 'o':
+ p.fmt.integer(int64(v), 8, unsigned, ldigits)
+ case 'x':
+ p.fmt.integer(int64(v), 16, unsigned, ldigits)
+ case 'X':
+ p.fmt.integer(int64(v), 16, unsigned, udigits)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtFloat32(v float32, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.fmt_fb32(v)
+ case 'e':
+ p.fmt.fmt_e32(v)
+ case 'E':
+ p.fmt.fmt_E32(v)
+ case 'f':
+ p.fmt.fmt_f32(v)
+ case 'g', 'v':
+ p.fmt.fmt_g32(v)
+ case 'G':
+ p.fmt.fmt_G32(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtFloat64(v float64, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.fmt_fb64(v)
+ case 'e':
+ p.fmt.fmt_e64(v)
+ case 'E':
+ p.fmt.fmt_E64(v)
+ case 'f':
+ p.fmt.fmt_f64(v)
+ case 'g', 'v':
+ p.fmt.fmt_g64(v)
+ case 'G':
+ p.fmt.fmt_G64(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtComplex64(v complex64, verb int, value interface{}) {
+ switch verb {
+ case 'e', 'E', 'f', 'F', 'g', 'G':
+ p.fmt.fmt_c64(v, verb)
+ case 'v':
+ p.fmt.fmt_c64(v, 'g')
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtComplex128(v complex128, verb int, value interface{}) {
+ switch verb {
+ case 'e', 'E', 'f', 'F', 'g', 'G':
+ p.fmt.fmt_c128(v, verb)
+ case 'v':
+ p.fmt.fmt_c128(v, 'g')
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) {
+ switch verb {
+ case 'v':
+ if goSyntax {
+ p.fmt.fmt_q(v)
+ } else {
+ p.fmt.fmt_s(v)
+ }
+ case 's':
+ p.fmt.fmt_s(v)
+ case 'x':
+ p.fmt.fmt_sx(v)
+ case 'X':
+ p.fmt.fmt_sX(v)
+ case 'q':
+ p.fmt.fmt_q(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interface{}) {
+ if verb == 'v' || verb == 'd' {
+ if goSyntax {
+ p.buf.Write(bytesBytes)
+ } else {
+ p.buf.WriteByte('[')
+ }
+ for i, c := range v {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(c, 'v', p.fmt.plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ return
+ }
+ s := string(v)
+ switch verb {
+ case 's':
+ p.fmt.fmt_s(s)
+ case 'x':
+ p.fmt.fmt_sx(s)
+ case 'X':
+ p.fmt.fmt_sX(s)
+ case 'q':
+ p.fmt.fmt_q(s)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) {
+ var u uintptr
+ switch value.(type) {
+ case *reflect.ChanValue, *reflect.FuncValue, *reflect.MapValue, *reflect.PtrValue, *reflect.SliceValue, *reflect.UnsafePointerValue:
+ u = value.Pointer()
+ default:
+ p.badVerb(verb, field)
+ return
+ }
+ if goSyntax {
+ p.add('(')
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.add(')')
+ p.add('(')
+ if u == 0 {
+ p.buf.Write(nilBytes)
+ } else {
+ p.fmt0x64(uint64(u), true)
+ }
+ p.add(')')
+ } else {
+ p.fmt0x64(uint64(u), !p.fmt.sharp)
+ }
+}
+
+var (
+ intBits = reflect.Typeof(0).Bits()
+ floatBits = reflect.Typeof(0.0).Bits()
+ complexBits = reflect.Typeof(1i).Bits()
+ uintptrBits = reflect.Typeof(uintptr(0)).Bits()
+)
+
+func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) {
+ if field == nil {
+ if verb == 'T' || verb == 'v' {
+ p.buf.Write(nilAngleBytes)
+ } else {
+ p.badVerb(verb, field)
+ }
+ return false
+ }
+
+ // Special processing considerations.
+ // %T (the value's type) and %p (its address) are special; we always do them first.
+ switch verb {
+ case 'T':
+ p.printField(reflect.Typeof(field).String(), 's', false, false, 0)
+ return false
+ case 'p':
+ p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax)
+ return false
+ }
+ // Is it a Formatter?
+ if formatter, ok := field.(Formatter); ok {
+ formatter.Format(p, verb)
+ return false // this value is not a string
+
+ }
+ // Must not touch flags before Formatter looks at them.
+ if plus {
+ p.fmt.plus = false
+ }
+ // If we're doing Go syntax and the field knows how to supply it, take care of it now.
+ if goSyntax {
+ p.fmt.sharp = false
+ if stringer, ok := field.(GoStringer); ok {
+ // Print the result of GoString unadorned.
+ p.fmtString(stringer.GoString(), 's', false, field)
+ return false // this value is not a string
+ }
+ } else {
+ // Is it a Stringer?
+ if stringer, ok := field.(Stringer); ok {
+ p.printField(stringer.String(), verb, plus, false, depth)
+ return false // this value is not a string
+ }
+ }
+
+ // Some types can be done without reflection.
+ switch f := field.(type) {
+ case bool:
+ p.fmtBool(f, verb, field)
+ return false
+ case float32:
+ p.fmtFloat32(f, verb, field)
+ return false
+ case float64:
+ p.fmtFloat64(f, verb, field)
+ return false
+ case complex64:
+ p.fmtComplex64(complex64(f), verb, field)
+ return false
+ case complex128:
+ p.fmtComplex128(f, verb, field)
+ return false
+ case int:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int8:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int16:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int32:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int64:
+ p.fmtInt64(f, verb, field)
+ return false
+ case uint:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint8:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint16:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint32:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint64:
+ p.fmtUint64(f, verb, goSyntax, field)
+ return false
+ case uintptr:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case string:
+ p.fmtString(f, verb, goSyntax, field)
+ return verb == 's' || verb == 'v'
+ case []byte:
+ p.fmtBytes(f, verb, goSyntax, depth, field)
+ return verb == 's'
+ }
+
+ // Need to use reflection
+ value := reflect.NewValue(field)
+
+BigSwitch:
+ switch f := value.(type) {
+ case *reflect.BoolValue:
+ p.fmtBool(f.Get(), verb, field)
+ case *reflect.IntValue:
+ p.fmtInt64(f.Get(), verb, field)
+ case *reflect.UintValue:
+ p.fmtUint64(uint64(f.Get()), verb, goSyntax, field)
+ case *reflect.FloatValue:
+ if f.Type().Size() == 4 {
+ p.fmtFloat32(float32(f.Get()), verb, field)
+ } else {
+ p.fmtFloat64(float64(f.Get()), verb, field)
+ }
+ case *reflect.ComplexValue:
+ if f.Type().Size() == 8 {
+ p.fmtComplex64(complex64(f.Get()), verb, field)
+ } else {
+ p.fmtComplex128(complex128(f.Get()), verb, field)
+ }
+ case *reflect.StringValue:
+ p.fmtString(f.Get(), verb, goSyntax, field)
+ case *reflect.MapValue:
+ if goSyntax {
+ p.buf.WriteString(f.Type().String())
+ p.buf.WriteByte('{')
+ } else {
+ p.buf.Write(mapBytes)
+ }
+ keys := f.Keys()
+ for i, key := range keys {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(key.Interface(), verb, plus, goSyntax, depth+1)
+ p.buf.WriteByte(':')
+ p.printField(f.Elem(key).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ case *reflect.StructValue:
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ }
+ p.add('{')
+ v := f
+ t := v.Type().(*reflect.StructType)
+ for i := 0; i < v.NumField(); i++ {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ if plus || goSyntax {
+ if f := t.Field(i); f.Name != "" {
+ p.buf.WriteString(f.Name)
+ p.buf.WriteByte(':')
+ }
+ }
+ p.printField(getField(v, i).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ p.buf.WriteByte('}')
+ case *reflect.InterfaceValue:
+ value := f.Elem()
+ if value == nil {
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.Write(nilParenBytes)
+ } else {
+ p.buf.Write(nilAngleBytes)
+ }
+ } else {
+ return p.printField(value.Interface(), verb, plus, goSyntax, depth+1)
+ }
+ case reflect.ArrayOrSliceValue:
+ // Byte slices are special.
+ if f.Type().(reflect.ArrayOrSliceType).Elem().Kind() == reflect.Uint8 {
+ // We know it's a slice of bytes, but we also know it does not have static type
+ // []byte, or it would have been caught above. Therefore we cannot convert
+ // it directly in the (slightly) obvious way: f.Interface().([]byte); it doesn't have
+ // that type, and we can't write an expression of the right type and do a
+ // conversion because we don't have a static way to write the right type.
+ // So we build a slice by hand. This is a rare case but it would be nice
+ // if reflection could help a little more.
+ bytes := make([]byte, f.Len())
+ for i := range bytes {
+ bytes[i] = byte(f.Elem(i).(*reflect.UintValue).Get())
+ }
+ p.fmtBytes(bytes, verb, goSyntax, depth, field)
+ return verb == 's'
+ }
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte('{')
+ } else {
+ p.buf.WriteByte('[')
+ }
+ for i := 0; i < f.Len(); i++ {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(f.Elem(i).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ case *reflect.PtrValue:
+ v := f.Get()
+ // pointer to array or slice or struct? ok at top level
+ // but not embedded (avoid loops)
+ if v != 0 && depth == 0 {
+ switch a := f.Elem().(type) {
+ case reflect.ArrayOrSliceValue:
+ p.buf.WriteByte('&')
+ p.printField(a.Interface(), verb, plus, goSyntax, depth+1)
+ break BigSwitch
+ case *reflect.StructValue:
+ p.buf.WriteByte('&')
+ p.printField(a.Interface(), verb, plus, goSyntax, depth+1)
+ break BigSwitch
+ }
+ }
+ if goSyntax {
+ p.buf.WriteByte('(')
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte(')')
+ p.buf.WriteByte('(')
+ if v == 0 {
+ p.buf.Write(nilBytes)
+ } else {
+ p.fmt0x64(uint64(v), true)
+ }
+ p.buf.WriteByte(')')
+ break
+ }
+ if v == 0 {
+ p.buf.Write(nilAngleBytes)
+ break
+ }
+ p.fmt0x64(uint64(v), true)
+ case *reflect.ChanValue, *reflect.FuncValue, *reflect.UnsafePointerValue:
+ p.fmtPointer(field, value, verb, goSyntax)
+ default:
+ p.unknownType(f)
+ }
+ return false
+}
+
+// intFromArg gets the fieldnumth element of a. On return, isInt reports whether the argument has type int.
+func intFromArg(a []interface{}, end, i, fieldnum int) (num int, isInt bool, newi, newfieldnum int) {
+ newi, newfieldnum = end, fieldnum
+ if i < end && fieldnum < len(a) {
+ num, isInt = a[fieldnum].(int)
+ newi, newfieldnum = i+1, fieldnum+1
+ }
+ return
+}
+
+func (p *pp) doPrintf(format string, a []interface{}) {
+ end := len(format)
+ fieldnum := 0 // we process one field per non-trivial format
+ for i := 0; i < end; {
+ lasti := i
+ for i < end && format[i] != '%' {
+ i++
+ }
+ if i > lasti {
+ p.buf.WriteString(format[lasti:i])
+ }
+ if i >= end {
+ // done processing format string
+ break
+ }
+
+ // Process one verb
+ i++
+ // flags and widths
+ p.fmt.clearflags()
+ F:
+ for ; i < end; i++ {
+ switch format[i] {
+ case '#':
+ p.fmt.sharp = true
+ case '0':
+ p.fmt.zero = true
+ case '+':
+ p.fmt.plus = true
+ case '-':
+ p.fmt.minus = true
+ case ' ':
+ p.fmt.space = true
+ default:
+ break F
+ }
+ }
+ // do we have width?
+ if i < end && format[i] == '*' {
+ p.fmt.wid, p.fmt.widPresent, i, fieldnum = intFromArg(a, end, i, fieldnum)
+ if !p.fmt.widPresent {
+ p.buf.Write(widthBytes)
+ }
+ } else {
+ p.fmt.wid, p.fmt.widPresent, i = parsenum(format, i, end)
+ }
+ // do we have precision?
+ if i < end && format[i] == '.' {
+ if format[i+1] == '*' {
+ p.fmt.prec, p.fmt.precPresent, i, fieldnum = intFromArg(a, end, i+1, fieldnum)
+ if !p.fmt.precPresent {
+ p.buf.Write(precBytes)
+ }
+ } else {
+ p.fmt.prec, p.fmt.precPresent, i = parsenum(format, i+1, end)
+ }
+ }
+ if i >= end {
+ p.buf.Write(noVerbBytes)
+ continue
+ }
+ c, w := utf8.DecodeRuneInString(format[i:])
+ i += w
+ // percent is special - absorbs no operand
+ if c == '%' {
+ p.buf.WriteByte('%') // We ignore width and prec.
+ continue
+ }
+ if fieldnum >= len(a) { // out of operands
+ p.buf.WriteByte('%')
+ p.add(c)
+ p.buf.Write(missingBytes)
+ continue
+ }
+ field := a[fieldnum]
+ fieldnum++
+
+ goSyntax := c == 'v' && p.fmt.sharp
+ plus := c == 'v' && p.fmt.plus
+ p.printField(field, c, plus, goSyntax, 0)
+ }
+
+ if fieldnum < len(a) {
+ p.buf.Write(extraBytes)
+ for ; fieldnum < len(a); fieldnum++ {
+ field := a[fieldnum]
+ if field != nil {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte('=')
+ }
+ p.printField(field, 'v', false, false, 0)
+ if fieldnum+1 < len(a) {
+ p.buf.Write(commaSpaceBytes)
+ }
+ }
+ p.buf.WriteByte(')')
+ }
+}
+
+func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) {
+ prevString := false
+ for fieldnum := 0; fieldnum < len(a); fieldnum++ {
+ p.fmt.clearflags()
+ // always add spaces if we're doing println
+ field := a[fieldnum]
+ if fieldnum > 0 {
+ isString := field != nil && reflect.Typeof(field).Kind() == reflect.String
+ if addspace || !isString && !prevString {
+ p.buf.WriteByte(' ')
+ }
+ }
+ prevString = p.printField(field, 'v', false, false, 0)
+ }
+ if addnewline {
+ p.buf.WriteByte('\n')
+ }
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "reflect"
+ "utf8"
+)
+
+// Some constants in the form of bytes, to avoid string overhead.
+// Needlessly fastidious, I suppose.
+var (
+ commaSpaceBytes = []byte(", ")
+ nilAngleBytes = []byte("<nil>")
+ nilParenBytes = []byte("(nil)")
+ nilBytes = []byte("nil")
+ mapBytes = []byte("map[")
+ missingBytes = []byte("(MISSING)")
+ extraBytes = []byte("%!(EXTRA ")
+ irparenBytes = []byte("i)")
+ bytesBytes = []byte("[]byte{")
+ widthBytes = []byte("%!(BADWIDTH)")
+ precBytes = []byte("%!(BADPREC)")
+ noVerbBytes = []byte("%!(NOVERB)")
+)
+
+// State represents the printer state passed to custom formatters.
+// It provides access to the io.Writer interface plus information about
+// the flags and options for the operand's format specifier.
+type State interface {
+ // Write is the function to call to emit formatted output to be printed.
+ Write(b []byte) (ret int, err os.Error)
+ // Width returns the value of the width option and whether it has been set.
+ Width() (wid int, ok bool)
+ // Precision returns the value of the precision option and whether it has been set.
+ Precision() (prec int, ok bool)
+
+ // Flag returns whether the flag c, a character, has been set.
+ Flag(int) bool
+}
+
+// Formatter is the interface implemented by values with a custom formatter.
+// The implementation of Format may call Sprintf or Fprintf(f) etc.
+// to generate its output.
+type Formatter interface {
+ Format(f State, c int)
+}
+
+// Stringer is implemented by any value that has a String method(),
+// which defines the ``native'' format for that value.
+// The String method is used to print values passed as an operand
+// to a %s or %v format or to an unformatted printer such as Print.
+type Stringer interface {
+ String() string
+}
+
+// GoStringer is implemented by any value that has a GoString() method,
+// which defines the Go syntax for that value.
+// The GoString method is used to print values passed as an operand
+// to a %#v format.
+type GoStringer interface {
+ GoString() string
+}
+
+type pp struct {
+ n int
+ buf bytes.Buffer
+ runeBuf [utf8.UTFMax]byte
+ fmt fmt
+}
+
+// A cache holds a set of reusable objects.
+// The buffered channel holds the currently available objects.
+// If more are needed, the cache creates them by calling new.
+type cache struct {
+ saved chan interface{}
+ new func() interface{}
+}
+
+func (c *cache) put(x interface{}) {
+ select {
+ case c.saved <- x:
+ // saved in cache
+ default:
+ // discard
+ }
+}
+
+func (c *cache) get() interface{} {
+ select {
+ case x := <-c.saved:
+ return x // reused from cache
+ default:
+ return c.new()
+ }
+ panic("not reached")
+}
+
+func newCache(f func() interface{}) *cache {
+ return &cache{make(chan interface{}, 100), f}
+}
+
+var ppFree = newCache(func() interface{} { return new(pp) })
+
+// Allocate a new pp struct or grab a cached one.
+func newPrinter() *pp {
+ p := ppFree.get().(*pp)
+ p.fmt.init(&p.buf)
+ return p
+}
+
+// Save used pp structs in ppFree; avoids an allocation per invocation.
+func (p *pp) free() {
+ // Don't hold on to pp structs with large buffers.
+ if cap(p.buf.Bytes()) > 1024 {
+ return
+ }
+ p.buf.Reset()
+ ppFree.put(p)
+}
+
+func (p *pp) Width() (wid int, ok bool) { return p.fmt.wid, p.fmt.widPresent }
+
+func (p *pp) Precision() (prec int, ok bool) { return p.fmt.prec, p.fmt.precPresent }
+
+func (p *pp) Flag(b int) bool {
+ switch b {
+ case '-':
+ return p.fmt.minus
+ case '+':
+ return p.fmt.plus
+ case '#':
+ return p.fmt.sharp
+ case ' ':
+ return p.fmt.space
+ case '0':
+ return p.fmt.zero
+ }
+ return false
+}
+
+func (p *pp) add(c int) {
+ p.buf.WriteRune(c)
+}
+
+// Implement Write so we can call Fprintf on a pp (through State), for
+// recursive use in custom verbs.
+func (p *pp) Write(b []byte) (ret int, err os.Error) {
+ return p.buf.Write(b)
+}
+
+// These routines end in 'f' and take a format string.
+
+// Fprintf formats according to a format specifier and writes to w.
+// It returns the number of bytes written and any write error encountered.
+func Fprintf(w io.Writer, format string, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrintf(format, a)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Printf formats according to a format specifier and writes to standard output.
+// It returns the number of bytes written and any write error encountered.
+func Printf(format string, a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprintf(os.Stdout, format, a...)
+ return n, errno
+}
+
+// Sprintf formats according to a format specifier and returns the resulting string.
+func Sprintf(format string, a ...interface{}) string {
+ p := newPrinter()
+ p.doPrintf(format, a)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// Errorf formats according to a format specifier and returns the string
+// converted to an os.ErrorString, which satisfies the os.Error interface.
+func Errorf(format string, a ...interface{}) os.Error {
+ return os.ErrorString(Sprintf(format, a...))
+}
+
+// These routines do not take a format string
+
+// Fprint formats using the default formats for its operands and writes to w.
+// Spaces are added between operands when neither is a string.
+// It returns the number of bytes written and any write error encountered.
+func Fprint(w io.Writer, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrint(a, false, false)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Print formats using the default formats for its operands and writes to standard output.
+// Spaces are added between operands when neither is a string.
+// It returns the number of bytes written and any write error encountered.
+func Print(a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprint(os.Stdout, a...)
+ return n, errno
+}
+
+// Sprint formats using the default formats for its operands and returns the resulting string.
+// Spaces are added between operands when neither is a string.
+func Sprint(a ...interface{}) string {
+ p := newPrinter()
+ p.doPrint(a, false, false)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+// These routines end in 'ln', do not take a format string,
+// always add spaces between operands, and add a newline
+// after the last operand.
+
+// Fprintln formats using the default formats for its operands and writes to w.
+// Spaces are always added between operands and a newline is appended.
+// It returns the number of bytes written and any write error encountered.
+func Fprintln(w io.Writer, a ...interface{}) (n int, error os.Error) {
+ p := newPrinter()
+ p.doPrint(a, true, true)
+ n64, error := p.buf.WriteTo(w)
+ p.free()
+ return int(n64), error
+}
+
+// Println formats using the default formats for its operands and writes to standard output.
+// Spaces are always added between operands and a newline is appended.
+// It returns the number of bytes written and any write error encountered.
+func Println(a ...interface{}) (n int, errno os.Error) {
+ n, errno = Fprintln(os.Stdout, a...)
+ return n, errno
+}
+
+// Sprintln formats using the default formats for its operands and returns the resulting string.
+// Spaces are always added between operands and a newline is appended.
+func Sprintln(a ...interface{}) string {
+ p := newPrinter()
+ p.doPrint(a, true, true)
+ s := p.buf.String()
+ p.free()
+ return s
+}
+
+
+// Get the i'th arg of the struct value.
+// If the arg itself is an interface, return a value for
+// the thing inside the interface, not the interface itself.
+func getField(v reflect.Value, i int) reflect.Value {
+ val := v.Field(i)
+ if i := val; i.Kind() == reflect.Interface {
+ if inter := i.Interface(); inter != nil {
+ return reflect.NewValue(inter)
+ }
+ }
+ return val
+}
+
+// Convert ASCII to integer. n is 0 (and got is false) if no number present.
+func parsenum(s string, start, end int) (num int, isnum bool, newi int) {
+ if start >= end {
+ return 0, false, end
+ }
+ for newi = start; newi < end && '0' <= s[newi] && s[newi] <= '9'; newi++ {
+ num = num*10 + int(s[newi]-'0')
+ isnum = true
+ }
+ return
+}
+
+func (p *pp) unknownType(v interface{}) {
+ if v == nil {
+ p.buf.Write(nilAngleBytes)
+ return
+ }
+ p.buf.WriteByte('?')
+ p.buf.WriteString(reflect.Typeof(v).String())
+ p.buf.WriteByte('?')
+}
+
+func (p *pp) badVerb(verb int, val interface{}) {
+ p.add('%')
+ p.add('!')
+ p.add(verb)
+ p.add('(')
+ if val == nil {
+ p.buf.Write(nilAngleBytes)
+ } else {
+ p.buf.WriteString(reflect.Typeof(val).String())
+ p.add('=')
+ p.printField(val, 'v', false, false, 0)
+ }
+ p.add(')')
+}
+
+func (p *pp) fmtBool(v bool, verb int, value interface{}) {
+ switch verb {
+ case 't', 'v':
+ p.fmt.fmt_boolean(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+// fmtC formats a rune for the 'c' format.
+func (p *pp) fmtC(c int64) {
+ rune := int(c) // Check for overflow.
+ if int64(rune) != c {
+ rune = utf8.RuneError
+ }
+ w := utf8.EncodeRune(p.runeBuf[0:utf8.UTFMax], rune)
+ p.fmt.pad(p.runeBuf[0:w])
+}
+
+func (p *pp) fmtInt64(v int64, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.integer(v, 2, signed, ldigits)
+ case 'c':
+ p.fmtC(v)
+ case 'd', 'v':
+ p.fmt.integer(v, 10, signed, ldigits)
+ case 'o':
+ p.fmt.integer(v, 8, signed, ldigits)
+ case 'x':
+ p.fmt.integer(v, 16, signed, ldigits)
+ case 'U':
+ p.fmtUnicode(v)
+ case 'X':
+ p.fmt.integer(v, 16, signed, udigits)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+// fmt0x64 formats a uint64 in hexadecimal and prefixes it with 0x or
+// not, as requested, by temporarily setting the sharp flag.
+func (p *pp) fmt0x64(v uint64, leading0x bool) {
+ sharp := p.fmt.sharp
+ p.fmt.sharp = leading0x
+ p.fmt.integer(int64(v), 16, unsigned, ldigits)
+ p.fmt.sharp = sharp
+}
+
+// fmtUnicode formats a uint64 in U+1234 form by
+// temporarily turning on the unicode flag and tweaking the precision.
+func (p *pp) fmtUnicode(v int64) {
+ precPresent := p.fmt.precPresent
+ prec := p.fmt.prec
+ if !precPresent {
+ // If prec is already set, leave it alone; otherwise 4 is minimum.
+ p.fmt.prec = 4
+ p.fmt.precPresent = true
+ }
+ p.fmt.unicode = true // turn on U+
+ p.fmt.integer(int64(v), 16, unsigned, udigits)
+ p.fmt.unicode = false
+ p.fmt.prec = prec
+ p.fmt.precPresent = precPresent
+}
+
+func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.integer(int64(v), 2, unsigned, ldigits)
+ case 'c':
+ p.fmtC(int64(v))
+ case 'd':
+ p.fmt.integer(int64(v), 10, unsigned, ldigits)
+ case 'v':
+ if goSyntax {
+ p.fmt0x64(v, true)
+ } else {
+ p.fmt.integer(int64(v), 10, unsigned, ldigits)
+ }
+ case 'o':
+ p.fmt.integer(int64(v), 8, unsigned, ldigits)
+ case 'x':
+ p.fmt.integer(int64(v), 16, unsigned, ldigits)
+ case 'X':
+ p.fmt.integer(int64(v), 16, unsigned, udigits)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtFloat32(v float32, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.fmt_fb32(v)
+ case 'e':
+ p.fmt.fmt_e32(v)
+ case 'E':
+ p.fmt.fmt_E32(v)
+ case 'f':
+ p.fmt.fmt_f32(v)
+ case 'g', 'v':
+ p.fmt.fmt_g32(v)
+ case 'G':
+ p.fmt.fmt_G32(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtFloat64(v float64, verb int, value interface{}) {
+ switch verb {
+ case 'b':
+ p.fmt.fmt_fb64(v)
+ case 'e':
+ p.fmt.fmt_e64(v)
+ case 'E':
+ p.fmt.fmt_E64(v)
+ case 'f':
+ p.fmt.fmt_f64(v)
+ case 'g', 'v':
+ p.fmt.fmt_g64(v)
+ case 'G':
+ p.fmt.fmt_G64(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtComplex64(v complex64, verb int, value interface{}) {
+ switch verb {
+ case 'e', 'E', 'f', 'F', 'g', 'G':
+ p.fmt.fmt_c64(v, verb)
+ case 'v':
+ p.fmt.fmt_c64(v, 'g')
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtComplex128(v complex128, verb int, value interface{}) {
+ switch verb {
+ case 'e', 'E', 'f', 'F', 'g', 'G':
+ p.fmt.fmt_c128(v, verb)
+ case 'v':
+ p.fmt.fmt_c128(v, 'g')
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) {
+ switch verb {
+ case 'v':
+ if goSyntax {
+ p.fmt.fmt_q(v)
+ } else {
+ p.fmt.fmt_s(v)
+ }
+ case 's':
+ p.fmt.fmt_s(v)
+ case 'x':
+ p.fmt.fmt_sx(v)
+ case 'X':
+ p.fmt.fmt_sX(v)
+ case 'q':
+ p.fmt.fmt_q(v)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interface{}) {
+ if verb == 'v' || verb == 'd' {
+ if goSyntax {
+ p.buf.Write(bytesBytes)
+ } else {
+ p.buf.WriteByte('[')
+ }
+ for i, c := range v {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(c, 'v', p.fmt.plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ return
+ }
+ s := string(v)
+ switch verb {
+ case 's':
+ p.fmt.fmt_s(s)
+ case 'x':
+ p.fmt.fmt_sx(s)
+ case 'X':
+ p.fmt.fmt_sX(s)
+ case 'q':
+ p.fmt.fmt_q(s)
+ default:
+ p.badVerb(verb, value)
+ }
+}
+
+func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) {
+ var u uintptr
+ switch value.Kind() {
+ case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
+ u = value.Pointer()
+ default:
+ p.badVerb(verb, field)
+ return
+ }
+ if goSyntax {
+ p.add('(')
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.add(')')
+ p.add('(')
+ if u == 0 {
+ p.buf.Write(nilBytes)
+ } else {
+ p.fmt0x64(uint64(u), true)
+ }
+ p.add(')')
+ } else {
+ p.fmt0x64(uint64(u), !p.fmt.sharp)
+ }
+}
+
+var (
+ intBits = reflect.Typeof(0).Bits()
+ floatBits = reflect.Typeof(0.0).Bits()
+ complexBits = reflect.Typeof(1i).Bits()
+ uintptrBits = reflect.Typeof(uintptr(0)).Bits()
+)
+
+func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) {
+ if field == nil {
+ if verb == 'T' || verb == 'v' {
+ p.buf.Write(nilAngleBytes)
+ } else {
+ p.badVerb(verb, field)
+ }
+ return false
+ }
+
+ // Special processing considerations.
+ // %T (the value's type) and %p (its address) are special; we always do them first.
+ switch verb {
+ case 'T':
+ p.printField(reflect.Typeof(field).String(), 's', false, false, 0)
+ return false
+ case 'p':
+ p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax)
+ return false
+ }
+ // Is it a Formatter?
+ if formatter, ok := field.(Formatter); ok {
+ formatter.Format(p, verb)
+ return false // this value is not a string
+
+ }
+ // Must not touch flags before Formatter looks at them.
+ if plus {
+ p.fmt.plus = false
+ }
+ // If we're doing Go syntax and the field knows how to supply it, take care of it now.
+ if goSyntax {
+ p.fmt.sharp = false
+ if stringer, ok := field.(GoStringer); ok {
+ // Print the result of GoString unadorned.
+ p.fmtString(stringer.GoString(), 's', false, field)
+ return false // this value is not a string
+ }
+ } else {
+ // Is it a Stringer?
+ if stringer, ok := field.(Stringer); ok {
+ p.printField(stringer.String(), verb, plus, false, depth)
+ return false // this value is not a string
+ }
+ }
+
+ // Some types can be done without reflection.
+ switch f := field.(type) {
+ case bool:
+ p.fmtBool(f, verb, field)
+ return false
+ case float32:
+ p.fmtFloat32(f, verb, field)
+ return false
+ case float64:
+ p.fmtFloat64(f, verb, field)
+ return false
+ case complex64:
+ p.fmtComplex64(complex64(f), verb, field)
+ return false
+ case complex128:
+ p.fmtComplex128(f, verb, field)
+ return false
+ case int:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int8:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int16:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int32:
+ p.fmtInt64(int64(f), verb, field)
+ return false
+ case int64:
+ p.fmtInt64(f, verb, field)
+ return false
+ case uint:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint8:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint16:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint32:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case uint64:
+ p.fmtUint64(f, verb, goSyntax, field)
+ return false
+ case uintptr:
+ p.fmtUint64(uint64(f), verb, goSyntax, field)
+ return false
+ case string:
+ p.fmtString(f, verb, goSyntax, field)
+ return verb == 's' || verb == 'v'
+ case []byte:
+ p.fmtBytes(f, verb, goSyntax, depth, field)
+ return verb == 's'
+ }
+
+ // Need to use reflection
+ value := reflect.NewValue(field)
+
+BigSwitch:
+ switch f := value; f.Kind() {
+ case reflect.Bool:
+ p.fmtBool(f.Bool(), verb, field)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ p.fmtInt64(f.Int(), verb, field)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ p.fmtUint64(uint64(f.Uint()), verb, goSyntax, field)
+ case reflect.Float32, reflect.Float64:
+ if f.Type().Size() == 4 {
+ p.fmtFloat32(float32(f.Float()), verb, field)
+ } else {
+ p.fmtFloat64(float64(f.Float()), verb, field)
+ }
+ case reflect.Complex64, reflect.Complex128:
+ if f.Type().Size() == 8 {
+ p.fmtComplex64(complex64(f.Complex()), verb, field)
+ } else {
+ p.fmtComplex128(complex128(f.Complex()), verb, field)
+ }
+ case reflect.String:
+ p.fmtString(f.String(), verb, goSyntax, field)
+ case reflect.Map:
+ if goSyntax {
+ p.buf.WriteString(f.Type().String())
+ p.buf.WriteByte('{')
+ } else {
+ p.buf.Write(mapBytes)
+ }
+ keys := f.MapKeys()
+ for i, key := range keys {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(key.Interface(), verb, plus, goSyntax, depth+1)
+ p.buf.WriteByte(':')
+ p.printField(f.MapIndex(key).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ case reflect.Struct:
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ }
+ p.add('{')
+ v := f
+ t := v.Type()
+ for i := 0; i < v.NumField(); i++ {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ if plus || goSyntax {
+ if f := t.Field(i); f.Name != "" {
+ p.buf.WriteString(f.Name)
+ p.buf.WriteByte(':')
+ }
+ }
+ p.printField(getField(v, i).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ p.buf.WriteByte('}')
+ case reflect.Interface:
+ value := f.Elem()
+ if !value.IsValid() {
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.Write(nilParenBytes)
+ } else {
+ p.buf.Write(nilAngleBytes)
+ }
+ } else {
+ return p.printField(value.Interface(), verb, plus, goSyntax, depth+1)
+ }
+ case reflect.Array, reflect.Slice:
+ // Byte slices are special.
+ if f.Type().Elem().Kind() == reflect.Uint8 {
+ // We know it's a slice of bytes, but we also know it does not have static type
+ // []byte, or it would have been caught above. Therefore we cannot convert
+ // it directly in the (slightly) obvious way: f.Interface().([]byte); it doesn't have
+ // that type, and we can't write an expression of the right type and do a
+ // conversion because we don't have a static way to write the right type.
+ // So we build a slice by hand. This is a rare case but it would be nice
+ // if reflection could help a little more.
+ bytes := make([]byte, f.Len())
+ for i := range bytes {
+ bytes[i] = byte(f.Index(i).Uint())
+ }
+ p.fmtBytes(bytes, verb, goSyntax, depth, field)
+ return verb == 's'
+ }
+ if goSyntax {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte('{')
+ } else {
+ p.buf.WriteByte('[')
+ }
+ for i := 0; i < f.Len(); i++ {
+ if i > 0 {
+ if goSyntax {
+ p.buf.Write(commaSpaceBytes)
+ } else {
+ p.buf.WriteByte(' ')
+ }
+ }
+ p.printField(f.Index(i).Interface(), verb, plus, goSyntax, depth+1)
+ }
+ if goSyntax {
+ p.buf.WriteByte('}')
+ } else {
+ p.buf.WriteByte(']')
+ }
+ case reflect.Ptr:
+ v := f.Pointer()
+ // pointer to array or slice or struct? ok at top level
+ // but not embedded (avoid loops)
+ if v != 0 && depth == 0 {
+ switch a := f.Elem(); a.Kind() {
+ case reflect.Array, reflect.Slice:
+ p.buf.WriteByte('&')
+ p.printField(a.Interface(), verb, plus, goSyntax, depth+1)
+ break BigSwitch
+ case reflect.Struct:
+ p.buf.WriteByte('&')
+ p.printField(a.Interface(), verb, plus, goSyntax, depth+1)
+ break BigSwitch
+ }
+ }
+ if goSyntax {
+ p.buf.WriteByte('(')
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte(')')
+ p.buf.WriteByte('(')
+ if v == 0 {
+ p.buf.Write(nilBytes)
+ } else {
+ p.fmt0x64(uint64(v), true)
+ }
+ p.buf.WriteByte(')')
+ break
+ }
+ if v == 0 {
+ p.buf.Write(nilAngleBytes)
+ break
+ }
+ p.fmt0x64(uint64(v), true)
+ case reflect.Chan, reflect.Func, reflect.UnsafePointer:
+ p.fmtPointer(field, value, verb, goSyntax)
+ default:
+ p.unknownType(f)
+ }
+ return false
+}
+
+// intFromArg gets the fieldnumth element of a. On return, isInt reports whether the argument has type int.
+func intFromArg(a []interface{}, end, i, fieldnum int) (num int, isInt bool, newi, newfieldnum int) {
+ newi, newfieldnum = end, fieldnum
+ if i < end && fieldnum < len(a) {
+ num, isInt = a[fieldnum].(int)
+ newi, newfieldnum = i+1, fieldnum+1
+ }
+ return
+}
+
+func (p *pp) doPrintf(format string, a []interface{}) {
+ end := len(format)
+ fieldnum := 0 // we process one field per non-trivial format
+ for i := 0; i < end; {
+ lasti := i
+ for i < end && format[i] != '%' {
+ i++
+ }
+ if i > lasti {
+ p.buf.WriteString(format[lasti:i])
+ }
+ if i >= end {
+ // done processing format string
+ break
+ }
+
+ // Process one verb
+ i++
+ // flags and widths
+ p.fmt.clearflags()
+ F:
+ for ; i < end; i++ {
+ switch format[i] {
+ case '#':
+ p.fmt.sharp = true
+ case '0':
+ p.fmt.zero = true
+ case '+':
+ p.fmt.plus = true
+ case '-':
+ p.fmt.minus = true
+ case ' ':
+ p.fmt.space = true
+ default:
+ break F
+ }
+ }
+ // do we have width?
+ if i < end && format[i] == '*' {
+ p.fmt.wid, p.fmt.widPresent, i, fieldnum = intFromArg(a, end, i, fieldnum)
+ if !p.fmt.widPresent {
+ p.buf.Write(widthBytes)
+ }
+ } else {
+ p.fmt.wid, p.fmt.widPresent, i = parsenum(format, i, end)
+ }
+ // do we have precision?
+ if i < end && format[i] == '.' {
+ if format[i+1] == '*' {
+ p.fmt.prec, p.fmt.precPresent, i, fieldnum = intFromArg(a, end, i+1, fieldnum)
+ if !p.fmt.precPresent {
+ p.buf.Write(precBytes)
+ }
+ } else {
+ p.fmt.prec, p.fmt.precPresent, i = parsenum(format, i+1, end)
+ }
+ }
+ if i >= end {
+ p.buf.Write(noVerbBytes)
+ continue
+ }
+ c, w := utf8.DecodeRuneInString(format[i:])
+ i += w
+ // percent is special - absorbs no operand
+ if c == '%' {
+ p.buf.WriteByte('%') // We ignore width and prec.
+ continue
+ }
+ if fieldnum >= len(a) { // out of operands
+ p.buf.WriteByte('%')
+ p.add(c)
+ p.buf.Write(missingBytes)
+ continue
+ }
+ field := a[fieldnum]
+ fieldnum++
+
+ goSyntax := c == 'v' && p.fmt.sharp
+ plus := c == 'v' && p.fmt.plus
+ p.printField(field, c, plus, goSyntax, 0)
+ }
+
+ if fieldnum < len(a) {
+ p.buf.Write(extraBytes)
+ for ; fieldnum < len(a); fieldnum++ {
+ field := a[fieldnum]
+ if field != nil {
+ p.buf.WriteString(reflect.Typeof(field).String())
+ p.buf.WriteByte('=')
+ }
+ p.printField(field, 'v', false, false, 0)
+ if fieldnum+1 < len(a) {
+ p.buf.Write(commaSpaceBytes)
+ }
+ }
+ p.buf.WriteByte(')')
+ }
+}
+
+func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) {
+ prevString := false
+ for fieldnum := 0; fieldnum < len(a); fieldnum++ {
+ p.fmt.clearflags()
+ // always add spaces if we're doing println
+ field := a[fieldnum]
+ if fieldnum > 0 {
+ isString := field != nil && reflect.Typeof(field).Kind() == reflect.String
+ if addspace || !isString && !prevString {
+ p.buf.WriteByte(' ')
+ }
+ }
+ prevString = p.printField(field, 'v', false, false, 0)
+ }
+ if addnewline {
+ p.buf.WriteByte('\n')
+ }
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package implements utility functions to help with black box testing.
+package quick
+
+import (
+ "flag"
+ "fmt"
+ "math"
+ "os"
+ "rand"
+ "reflect"
+ "strings"
+)
+
+var defaultMaxCount *int = flag.Int("quickchecks", 100, "The default number of iterations for each check")
+
+// A Generator can generate random values of its own type.
+type Generator interface {
+ // Generate returns a random instance of the type on which it is a
+ // method using the size as a size hint.
+ Generate(rand *rand.Rand, size int) reflect.Value
+}
+
+// randFloat32 generates a random float taking the full range of a float32.
+func randFloat32(rand *rand.Rand) float32 {
+ f := rand.Float64() * math.MaxFloat32
+ if rand.Int()&1 == 1 {
+ f = -f
+ }
+ return float32(f)
+}
+
+// randFloat64 generates a random float taking the full range of a float64.
+func randFloat64(rand *rand.Rand) float64 {
+ f := rand.Float64()
+ if rand.Int()&1 == 1 {
+ f = -f
+ }
+ return f
+}
+
+// randInt64 returns a random integer taking half the range of an int64.
+func randInt64(rand *rand.Rand) int64 { return rand.Int63() - 1<<62 }
+
+// complexSize is the maximum length of arbitrary values that contain other
+// values.
+const complexSize = 50
+
+// Value returns an arbitrary value of the given type.
+// If the type implements the Generator interface, that will be used.
+// Note: in order to create arbitrary values for structs, all the members must be public.
+func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
+ if m, ok := reflect.MakeZero(t).Interface().(Generator); ok {
+ return m.Generate(rand, complexSize), true
+ }
+
+ switch concrete := t.(type) {
+ case *reflect.BoolType:
+ return reflect.NewValue(rand.Int()&1 == 0), true
+ case *reflect.FloatType, *reflect.IntType, *reflect.UintType, *reflect.ComplexType:
+ switch t.Kind() {
+ case reflect.Float32:
+ return reflect.NewValue(randFloat32(rand)), true
+ case reflect.Float64:
+ return reflect.NewValue(randFloat64(rand)), true
+ case reflect.Complex64:
+ return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true
+ case reflect.Complex128:
+ return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true
+ case reflect.Int16:
+ return reflect.NewValue(int16(randInt64(rand))), true
+ case reflect.Int32:
+ return reflect.NewValue(int32(randInt64(rand))), true
+ case reflect.Int64:
+ return reflect.NewValue(randInt64(rand)), true
+ case reflect.Int8:
+ return reflect.NewValue(int8(randInt64(rand))), true
+ case reflect.Int:
+ return reflect.NewValue(int(randInt64(rand))), true
+ case reflect.Uint16:
+ return reflect.NewValue(uint16(randInt64(rand))), true
+ case reflect.Uint32:
+ return reflect.NewValue(uint32(randInt64(rand))), true
+ case reflect.Uint64:
+ return reflect.NewValue(uint64(randInt64(rand))), true
+ case reflect.Uint8:
+ return reflect.NewValue(uint8(randInt64(rand))), true
+ case reflect.Uint:
+ return reflect.NewValue(uint(randInt64(rand))), true
+ case reflect.Uintptr:
+ return reflect.NewValue(uintptr(randInt64(rand))), true
+ }
+ case *reflect.MapType:
+ numElems := rand.Intn(complexSize)
+ m := reflect.MakeMap(concrete)
+ for i := 0; i < numElems; i++ {
+ key, ok1 := Value(concrete.Key(), rand)
+ value, ok2 := Value(concrete.Elem(), rand)
+ if !ok1 || !ok2 {
+ return nil, false
+ }
+ m.SetElem(key, value)
+ }
+ return m, true
+ case *reflect.PtrType:
+ v, ok := Value(concrete.Elem(), rand)
+ if !ok {
+ return nil, false
+ }
+ p := reflect.MakeZero(concrete)
+ p.(*reflect.PtrValue).PointTo(v)
+ return p, true
+ case *reflect.SliceType:
+ numElems := rand.Intn(complexSize)
+ s := reflect.MakeSlice(concrete, numElems, numElems)
+ for i := 0; i < numElems; i++ {
+ v, ok := Value(concrete.Elem(), rand)
+ if !ok {
+ return nil, false
+ }
+ s.Elem(i).SetValue(v)
+ }
+ return s, true
+ case *reflect.StringType:
+ numChars := rand.Intn(complexSize)
+ codePoints := make([]int, numChars)
+ for i := 0; i < numChars; i++ {
+ codePoints[i] = rand.Intn(0x10ffff)
+ }
+ return reflect.NewValue(string(codePoints)), true
+ case *reflect.StructType:
+ s := reflect.MakeZero(t).(*reflect.StructValue)
+ for i := 0; i < s.NumField(); i++ {
+ v, ok := Value(concrete.Field(i).Type, rand)
+ if !ok {
+ return nil, false
+ }
+ s.Field(i).SetValue(v)
+ }
+ return s, true
+ default:
+ return nil, false
+ }
+
+ return
+}
+
+// A Config structure contains options for running a test.
+type Config struct {
+ // MaxCount sets the maximum number of iterations. If zero,
+ // MaxCountScale is used.
+ MaxCount int
+ // MaxCountScale is a non-negative scale factor applied to the default
+ // maximum. If zero, the default is unchanged.
+ MaxCountScale float64
+ // If non-nil, rand is a source of random numbers. Otherwise a default
+ // pseudo-random source will be used.
+ Rand *rand.Rand
+ // If non-nil, Values is a function which generates a slice of arbitrary
+ // Values that are congruent with the arguments to the function being
+ // tested. Otherwise, Values is used to generate the values.
+ Values func([]reflect.Value, *rand.Rand)
+}
+
+var defaultConfig Config
+
+// getRand returns the *rand.Rand to use for a given Config.
+func (c *Config) getRand() *rand.Rand {
+ if c.Rand == nil {
+ return rand.New(rand.NewSource(0))
+ }
+ return c.Rand
+}
+
+// getMaxCount returns the maximum number of iterations to run for a given
+// Config.
+func (c *Config) getMaxCount() (maxCount int) {
+ maxCount = c.MaxCount
+ if maxCount == 0 {
+ if c.MaxCountScale != 0 {
+ maxCount = int(c.MaxCountScale * float64(*defaultMaxCount))
+ } else {
+ maxCount = *defaultMaxCount
+ }
+ }
+
+ return
+}
+
+// A SetupError is the result of an error in the way that check is being
+// used, independent of the functions being tested.
+type SetupError string
+
+func (s SetupError) String() string { return string(s) }
+
+// A CheckError is the result of Check finding an error.
+type CheckError struct {
+ Count int
+ In []interface{}
+}
+
+func (s *CheckError) String() string {
+ return fmt.Sprintf("#%d: failed on input %s", s.Count, toString(s.In))
+}
+
+// A CheckEqualError is the result CheckEqual finding an error.
+type CheckEqualError struct {
+ CheckError
+ Out1 []interface{}
+ Out2 []interface{}
+}
+
+func (s *CheckEqualError) String() string {
+ return fmt.Sprintf("#%d: failed on input %s. Output 1: %s. Output 2: %s", s.Count, toString(s.In), toString(s.Out1), toString(s.Out2))
+}
+
+// Check looks for an input to f, any function that returns bool,
+// such that f returns false. It calls f repeatedly, with arbitrary
+// values for each argument. If f returns false on a given input,
+// Check returns that input as a *CheckError.
+// For example:
+//
+// func TestOddMultipleOfThree(t *testing.T) {
+// f := func(x int) bool {
+// y := OddMultipleOfThree(x)
+// return y%2 == 1 && y%3 == 0
+// }
+// if err := quick.Check(f, nil); err != nil {
+// t.Error(err)
+// }
+// }
+func Check(function interface{}, config *Config) (err os.Error) {
+ if config == nil {
+ config = &defaultConfig
+ }
+
+ f, fType, ok := functionAndType(function)
+ if !ok {
+ err = SetupError("argument is not a function")
+ return
+ }
+
+ if fType.NumOut() != 1 {
+ err = SetupError("function returns more than one value.")
+ return
+ }
+ if _, ok := fType.Out(0).(*reflect.BoolType); !ok {
+ err = SetupError("function does not return a bool")
+ return
+ }
+
+ arguments := make([]reflect.Value, fType.NumIn())
+ rand := config.getRand()
+ maxCount := config.getMaxCount()
+
+ for i := 0; i < maxCount; i++ {
+ err = arbitraryValues(arguments, fType, config, rand)
+ if err != nil {
+ return
+ }
+
+ if !f.Call(arguments)[0].(*reflect.BoolValue).Get() {
+ err = &CheckError{i + 1, toInterfaces(arguments)}
+ return
+ }
+ }
+
+ return
+}
+
+// CheckEqual looks for an input on which f and g return different results.
+// It calls f and g repeatedly with arbitrary values for each argument.
+// If f and g return different answers, CheckEqual returns a *CheckEqualError
+// describing the input and the outputs.
+func CheckEqual(f, g interface{}, config *Config) (err os.Error) {
+ if config == nil {
+ config = &defaultConfig
+ }
+
+ x, xType, ok := functionAndType(f)
+ if !ok {
+ err = SetupError("f is not a function")
+ return
+ }
+ y, yType, ok := functionAndType(g)
+ if !ok {
+ err = SetupError("g is not a function")
+ return
+ }
+
+ if xType != yType {
+ err = SetupError("functions have different types")
+ return
+ }
+
+ arguments := make([]reflect.Value, xType.NumIn())
+ rand := config.getRand()
+ maxCount := config.getMaxCount()
+
+ for i := 0; i < maxCount; i++ {
+ err = arbitraryValues(arguments, xType, config, rand)
+ if err != nil {
+ return
+ }
+
+ xOut := toInterfaces(x.Call(arguments))
+ yOut := toInterfaces(y.Call(arguments))
+
+ if !reflect.DeepEqual(xOut, yOut) {
+ err = &CheckEqualError{CheckError{i + 1, toInterfaces(arguments)}, xOut, yOut}
+ return
+ }
+ }
+
+ return
+}
+
+// arbitraryValues writes Values to args such that args contains Values
+// suitable for calling f.
+func arbitraryValues(args []reflect.Value, f *reflect.FuncType, config *Config, rand *rand.Rand) (err os.Error) {
+ if config.Values != nil {
+ config.Values(args, rand)
+ return
+ }
+
+ for j := 0; j < len(args); j++ {
+ var ok bool
+ args[j], ok = Value(f.In(j), rand)
+ if !ok {
+ err = SetupError(fmt.Sprintf("cannot create arbitrary value of type %s for argument %d", f.In(j), j))
+ return
+ }
+ }
+
+ return
+}
+
+func functionAndType(f interface{}) (v *reflect.FuncValue, t *reflect.FuncType, ok bool) {
+ v, ok = reflect.NewValue(f).(*reflect.FuncValue)
+ if !ok {
+ return
+ }
+ t = v.Type().(*reflect.FuncType)
+ return
+}
+
+func toInterfaces(values []reflect.Value) []interface{} {
+ ret := make([]interface{}, len(values))
+ for i, v := range values {
+ ret[i] = v.Interface()
+ }
+ return ret
+}
+
+func toString(interfaces []interface{}) string {
+ s := make([]string, len(interfaces))
+ for i, v := range interfaces {
+ s[i] = fmt.Sprintf("%#v", v)
+ }
+ return strings.Join(s, ", ")
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package implements utility functions to help with black box testing.
+package quick
+
+import (
+ "flag"
+ "fmt"
+ "math"
+ "os"
+ "rand"
+ "reflect"
+ "strings"
+)
+
+var defaultMaxCount *int = flag.Int("quickchecks", 100, "The default number of iterations for each check")
+
+// A Generator can generate random values of its own type.
+type Generator interface {
+ // Generate returns a random instance of the type on which it is a
+ // method using the size as a size hint.
+ Generate(rand *rand.Rand, size int) reflect.Value
+}
+
+// randFloat32 generates a random float taking the full range of a float32.
+func randFloat32(rand *rand.Rand) float32 {
+ f := rand.Float64() * math.MaxFloat32
+ if rand.Int()&1 == 1 {
+ f = -f
+ }
+ return float32(f)
+}
+
+// randFloat64 generates a random float taking the full range of a float64.
+func randFloat64(rand *rand.Rand) float64 {
+ f := rand.Float64()
+ if rand.Int()&1 == 1 {
+ f = -f
+ }
+ return f
+}
+
+// randInt64 returns a random integer taking half the range of an int64.
+func randInt64(rand *rand.Rand) int64 { return rand.Int63() - 1<<62 }
+
+// complexSize is the maximum length of arbitrary values that contain other
+// values.
+const complexSize = 50
+
+// Value returns an arbitrary value of the given type.
+// If the type implements the Generator interface, that will be used.
+// Note: in order to create arbitrary values for structs, all the members must be public.
+func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
+ if m, ok := reflect.Zero(t).Interface().(Generator); ok {
+ return m.Generate(rand, complexSize), true
+ }
+
+ switch concrete := t; concrete.Kind() {
+ case reflect.Bool:
+ return reflect.NewValue(rand.Int()&1 == 0), true
+ case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Complex64, reflect.Complex128:
+ switch t.Kind() {
+ case reflect.Float32:
+ return reflect.NewValue(randFloat32(rand)), true
+ case reflect.Float64:
+ return reflect.NewValue(randFloat64(rand)), true
+ case reflect.Complex64:
+ return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true
+ case reflect.Complex128:
+ return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true
+ case reflect.Int16:
+ return reflect.NewValue(int16(randInt64(rand))), true
+ case reflect.Int32:
+ return reflect.NewValue(int32(randInt64(rand))), true
+ case reflect.Int64:
+ return reflect.NewValue(randInt64(rand)), true
+ case reflect.Int8:
+ return reflect.NewValue(int8(randInt64(rand))), true
+ case reflect.Int:
+ return reflect.NewValue(int(randInt64(rand))), true
+ case reflect.Uint16:
+ return reflect.NewValue(uint16(randInt64(rand))), true
+ case reflect.Uint32:
+ return reflect.NewValue(uint32(randInt64(rand))), true
+ case reflect.Uint64:
+ return reflect.NewValue(uint64(randInt64(rand))), true
+ case reflect.Uint8:
+ return reflect.NewValue(uint8(randInt64(rand))), true
+ case reflect.Uint:
+ return reflect.NewValue(uint(randInt64(rand))), true
+ case reflect.Uintptr:
+ return reflect.NewValue(uintptr(randInt64(rand))), true
+ }
+ case reflect.Map:
+ numElems := rand.Intn(complexSize)
+ m := reflect.MakeMap(concrete)
+ for i := 0; i < numElems; i++ {
+ key, ok1 := Value(concrete.Key(), rand)
+ value, ok2 := Value(concrete.Elem(), rand)
+ if !ok1 || !ok2 {
+ return reflect.Value{}, false
+ }
+ m.SetMapIndex(key, value)
+ }
+ return m, true
+ case reflect.Ptr:
+ v, ok := Value(concrete.Elem(), rand)
+ if !ok {
+ return reflect.Value{}, false
+ }
+ p := reflect.Zero(concrete)
+ p.Set(v.Addr())
+ return p, true
+ case reflect.Slice:
+ numElems := rand.Intn(complexSize)
+ s := reflect.MakeSlice(concrete, numElems, numElems)
+ for i := 0; i < numElems; i++ {
+ v, ok := Value(concrete.Elem(), rand)
+ if !ok {
+ return reflect.Value{}, false
+ }
+ s.Index(i).Set(v)
+ }
+ return s, true
+ case reflect.String:
+ numChars := rand.Intn(complexSize)
+ codePoints := make([]int, numChars)
+ for i := 0; i < numChars; i++ {
+ codePoints[i] = rand.Intn(0x10ffff)
+ }
+ return reflect.NewValue(string(codePoints)), true
+ case reflect.Struct:
+ s := reflect.Zero(t)
+ for i := 0; i < s.NumField(); i++ {
+ v, ok := Value(concrete.Field(i).Type, rand)
+ if !ok {
+ return reflect.Value{}, false
+ }
+ s.Field(i).Set(v)
+ }
+ return s, true
+ default:
+ return reflect.Value{}, false
+ }
+
+ return
+}
+
+// A Config structure contains options for running a test.
+type Config struct {
+ // MaxCount sets the maximum number of iterations. If zero,
+ // MaxCountScale is used.
+ MaxCount int
+ // MaxCountScale is a non-negative scale factor applied to the default
+ // maximum. If zero, the default is unchanged.
+ MaxCountScale float64
+ // If non-nil, rand is a source of random numbers. Otherwise a default
+ // pseudo-random source will be used.
+ Rand *rand.Rand
+ // If non-nil, Values is a function which generates a slice of arbitrary
+ // Values that are congruent with the arguments to the function being
+ // tested. Otherwise, Values is used to generate the values.
+ Values func([]reflect.Value, *rand.Rand)
+}
+
+var defaultConfig Config
+
+// getRand returns the *rand.Rand to use for a given Config.
+func (c *Config) getRand() *rand.Rand {
+ if c.Rand == nil {
+ return rand.New(rand.NewSource(0))
+ }
+ return c.Rand
+}
+
+// getMaxCount returns the maximum number of iterations to run for a given
+// Config.
+func (c *Config) getMaxCount() (maxCount int) {
+ maxCount = c.MaxCount
+ if maxCount == 0 {
+ if c.MaxCountScale != 0 {
+ maxCount = int(c.MaxCountScale * float64(*defaultMaxCount))
+ } else {
+ maxCount = *defaultMaxCount
+ }
+ }
+
+ return
+}
+
+// A SetupError is the result of an error in the way that check is being
+// used, independent of the functions being tested.
+type SetupError string
+
+func (s SetupError) String() string { return string(s) }
+
+// A CheckError is the result of Check finding an error.
+type CheckError struct {
+ Count int
+ In []interface{}
+}
+
+func (s *CheckError) String() string {
+ return fmt.Sprintf("#%d: failed on input %s", s.Count, toString(s.In))
+}
+
+// A CheckEqualError is the result CheckEqual finding an error.
+type CheckEqualError struct {
+ CheckError
+ Out1 []interface{}
+ Out2 []interface{}
+}
+
+func (s *CheckEqualError) String() string {
+ return fmt.Sprintf("#%d: failed on input %s. Output 1: %s. Output 2: %s", s.Count, toString(s.In), toString(s.Out1), toString(s.Out2))
+}
+
+// Check looks for an input to f, any function that returns bool,
+// such that f returns false. It calls f repeatedly, with arbitrary
+// values for each argument. If f returns false on a given input,
+// Check returns that input as a *CheckError.
+// For example:
+//
+// func TestOddMultipleOfThree(t *testing.T) {
+// f := func(x int) bool {
+// y := OddMultipleOfThree(x)
+// return y%2 == 1 && y%3 == 0
+// }
+// if err := quick.Check(f, nil); err != nil {
+// t.Error(err)
+// }
+// }
+func Check(function interface{}, config *Config) (err os.Error) {
+ if config == nil {
+ config = &defaultConfig
+ }
+
+ f, fType, ok := functionAndType(function)
+ if !ok {
+ err = SetupError("argument is not a function")
+ return
+ }
+
+ if fType.NumOut() != 1 {
+ err = SetupError("function returns more than one value.")
+ return
+ }
+ if fType.Out(0).Kind() != reflect.Bool {
+ err = SetupError("function does not return a bool")
+ return
+ }
+
+ arguments := make([]reflect.Value, fType.NumIn())
+ rand := config.getRand()
+ maxCount := config.getMaxCount()
+
+ for i := 0; i < maxCount; i++ {
+ err = arbitraryValues(arguments, fType, config, rand)
+ if err != nil {
+ return
+ }
+
+ if !f.Call(arguments)[0].Bool() {
+ err = &CheckError{i + 1, toInterfaces(arguments)}
+ return
+ }
+ }
+
+ return
+}
+
+// CheckEqual looks for an input on which f and g return different results.
+// It calls f and g repeatedly with arbitrary values for each argument.
+// If f and g return different answers, CheckEqual returns a *CheckEqualError
+// describing the input and the outputs.
+func CheckEqual(f, g interface{}, config *Config) (err os.Error) {
+ if config == nil {
+ config = &defaultConfig
+ }
+
+ x, xType, ok := functionAndType(f)
+ if !ok {
+ err = SetupError("f is not a function")
+ return
+ }
+ y, yType, ok := functionAndType(g)
+ if !ok {
+ err = SetupError("g is not a function")
+ return
+ }
+
+ if xType != yType {
+ err = SetupError("functions have different types")
+ return
+ }
+
+ arguments := make([]reflect.Value, xType.NumIn())
+ rand := config.getRand()
+ maxCount := config.getMaxCount()
+
+ for i := 0; i < maxCount; i++ {
+ err = arbitraryValues(arguments, xType, config, rand)
+ if err != nil {
+ return
+ }
+
+ xOut := toInterfaces(x.Call(arguments))
+ yOut := toInterfaces(y.Call(arguments))
+
+ if !reflect.DeepEqual(xOut, yOut) {
+ err = &CheckEqualError{CheckError{i + 1, toInterfaces(arguments)}, xOut, yOut}
+ return
+ }
+ }
+
+ return
+}
+
+// arbitraryValues writes Values to args such that args contains Values
+// suitable for calling f.
+func arbitraryValues(args []reflect.Value, f reflect.Type, config *Config, rand *rand.Rand) (err os.Error) {
+ if config.Values != nil {
+ config.Values(args, rand)
+ return
+ }
+
+ for j := 0; j < len(args); j++ {
+ var ok bool
+ args[j], ok = Value(f.In(j), rand)
+ if !ok {
+ err = SetupError(fmt.Sprintf("cannot create arbitrary value of type %s for argument %d", f.In(j), j))
+ return
+ }
+ }
+
+ return
+}
+
+func functionAndType(f interface{}) (v reflect.Value, t reflect.Type, ok bool) {
+ v = reflect.NewValue(f)
+ ok = v.Kind() == reflect.Func
+ if !ok {
+ return
+ }
+ t = v.Type()
+ return
+}
+
+func toInterfaces(values []reflect.Value) []interface{} {
+ ret := make([]interface{}, len(values))
+ for i, v := range values {
+ ret[i] = v.Interface()
+ }
+ return ret
+}
+
+func toString(interfaces []interface{}) string {
+ s := make([]string, len(interfaces))
+ for i, v := range interfaces {
+ s[i] = fmt.Sprintf("%#v", v)
+ }
+ return strings.Join(s, ", ")
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xml
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
+// an XML element is an order-dependent collection of anonymous
+// values, while a data structure is an order-independent collection
+// of named values.
+// See package json for a textual representation more suitable
+// to data structures.
+
+// Unmarshal parses an XML element from r and uses the
+// reflect library to fill in an arbitrary struct, slice, or string
+// pointed at by val. Well-formed data that does not fit
+// into val is discarded.
+//
+// For example, given these definitions:
+//
+// type Email struct {
+// Where string "attr"
+// Addr string
+// }
+//
+// type Result struct {
+// XMLName xml.Name "result"
+// Name string
+// Phone string
+// Email []Email
+// Groups []string "group>value"
+// }
+//
+// result := Result{Name: "name", Phone: "phone", Email: nil}
+//
+// unmarshalling the XML input
+//
+// <result>
+// <email where="home">
+// <addr>gre@example.com</addr>
+// </email>
+// <email where='work'>
+// <addr>gre@work.com</addr>
+// </email>
+// <name>Grace R. Emlin</name>
+// <group>
+// <value>Friends</value>
+// <value>Squash</value>
+// </group>
+// <address>123 Main Street</address>
+// </result>
+//
+// via Unmarshal(r, &result) is equivalent to assigning
+//
+// r = Result{xml.Name{"", "result"},
+// "Grace R. Emlin", // name
+// "phone", // no phone given
+// []Email{
+// Email{"home", "gre@example.com"},
+// Email{"work", "gre@work.com"},
+// },
+// []string{"Friends", "Squash"},
+// }
+//
+// Note that the field r.Phone has not been modified and
+// that the XML <address> element was discarded. Also, the field
+// Groups was assigned considering the element path provided in the
+// field tag.
+//
+// Because Unmarshal uses the reflect package, it can only
+// assign to upper case fields. Unmarshal uses a case-insensitive
+// comparison to match XML element names to struct field names.
+//
+// Unmarshal maps an XML element to a struct using the following rules:
+//
+// * If the struct has a field of type []byte or string with tag "innerxml",
+// Unmarshal accumulates the raw XML nested inside the element
+// in that field. The rest of the rules still apply.
+//
+// * If the struct has a field named XMLName of type xml.Name,
+// Unmarshal records the element name in that field.
+//
+// * If the XMLName field has an associated tag string of the form
+// "tag" or "namespace-URL tag", the XML element must have
+// the given tag (and, optionally, name space) or else Unmarshal
+// returns an error.
+//
+// * If the XML element has an attribute whose name matches a
+// struct field of type string with tag "attr", Unmarshal records
+// the attribute value in that field.
+//
+// * If the XML element contains character data, that data is
+// accumulated in the first struct field that has tag "chardata".
+// The struct field may have type []byte or string.
+// If there is no such field, the character data is discarded.
+//
+// * If the XML element contains a sub-element whose name matches
+// the prefix of a struct field tag formatted as "a>b>c", unmarshal
+// will descend into the XML structure looking for elements with the
+// given names, and will map the innermost elements to that struct field.
+// A struct field tag starting with ">" is equivalent to one starting
+// with the field name followed by ">".
+//
+// * If the XML element contains a sub-element whose name
+// matches a struct field whose tag is neither "attr" nor "chardata",
+// Unmarshal maps the sub-element to that struct field.
+// Otherwise, if the struct has a field named Any, unmarshal
+// maps the sub-element to that struct field.
+//
+// Unmarshal maps an XML element to a string or []byte by saving the
+// concatenation of that element's character data in the string or []byte.
+//
+// Unmarshal maps an XML element to a slice by extending the length
+// of the slice and mapping the element to the newly created value.
+//
+// Unmarshal maps an XML element to a bool by setting it to the boolean
+// value represented by the string.
+//
+// Unmarshal maps an XML element to an integer or floating-point
+// field by setting the field to the result of interpreting the string
+// value in decimal. There is no check for overflow.
+//
+// Unmarshal maps an XML element to an xml.Name by recording the
+// element name.
+//
+// Unmarshal maps an XML element to a pointer by setting the pointer
+// to a freshly allocated value and then mapping the element to that value.
+//
+func Unmarshal(r io.Reader, val interface{}) os.Error {
+ v, ok := reflect.NewValue(val).(*reflect.PtrValue)
+ if !ok {
+ return os.NewError("non-pointer passed to Unmarshal")
+ }
+ p := NewParser(r)
+ elem := v.Elem()
+ err := p.unmarshal(elem, nil)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// An UnmarshalError represents an error in the unmarshalling process.
+type UnmarshalError string
+
+func (e UnmarshalError) String() string { return string(e) }
+
+// A TagPathError represents an error in the unmarshalling process
+// caused by the use of field tags with conflicting paths.
+type TagPathError struct {
+ Struct reflect.Type
+ Field1, Tag1 string
+ Field2, Tag2 string
+}
+
+func (e *TagPathError) String() string {
+ return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
+}
+
+// The Parser's Unmarshal method is like xml.Unmarshal
+// except that it can be passed a pointer to the initial start element,
+// useful when a client reads some raw XML tokens itself
+// but also defers to Unmarshal for some elements.
+// Passing a nil start element indicates that Unmarshal should
+// read the token stream to find the start element.
+func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error {
+ v, ok := reflect.NewValue(val).(*reflect.PtrValue)
+ if !ok {
+ return os.NewError("non-pointer passed to Unmarshal")
+ }
+ return p.unmarshal(v.Elem(), start)
+}
+
+// fieldName strips invalid characters from an XML name
+// to create a valid Go struct name. It also converts the
+// name to lower case letters.
+func fieldName(original string) string {
+
+ var i int
+ //remove leading underscores
+ for i = 0; i < len(original) && original[i] == '_'; i++ {
+ }
+
+ return strings.Map(
+ func(x int) int {
+ if x == '_' || unicode.IsDigit(x) || unicode.IsLetter(x) {
+ return unicode.ToLower(x)
+ }
+ return -1
+ },
+ original[i:])
+}
+
+// Unmarshal a single XML element into val.
+func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error {
+ // Find start element if we need it.
+ if start == nil {
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ if t, ok := tok.(StartElement); ok {
+ start = &t
+ break
+ }
+ }
+ }
+
+ if pv, ok := val.(*reflect.PtrValue); ok {
+ if pv.Get() == 0 {
+ zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem())
+ pv.PointTo(zv)
+ val = zv
+ } else {
+ val = pv.Elem()
+ }
+ }
+
+ var (
+ data []byte
+ saveData reflect.Value
+ comment []byte
+ saveComment reflect.Value
+ saveXML reflect.Value
+ saveXMLIndex int
+ saveXMLData []byte
+ sv *reflect.StructValue
+ styp *reflect.StructType
+ fieldPaths map[string]pathInfo
+ )
+
+ switch v := val.(type) {
+ default:
+ return os.ErrorString("unknown type " + v.Type().String())
+
+ case *reflect.SliceValue:
+ typ := v.Type().(*reflect.SliceType)
+ if typ.Elem().Kind() == reflect.Uint8 {
+ // []byte
+ saveData = v
+ break
+ }
+
+ // Slice of element values.
+ // Grow slice.
+ n := v.Len()
+ if n >= v.Cap() {
+ ncap := 2 * n
+ if ncap < 4 {
+ ncap = 4
+ }
+ new := reflect.MakeSlice(typ, n, ncap)
+ reflect.Copy(new, v)
+ v.Set(new)
+ }
+ v.SetLen(n + 1)
+
+ // Recur to read element into slice.
+ if err := p.unmarshal(v.Elem(n), start); err != nil {
+ v.SetLen(n)
+ return err
+ }
+ return nil
+
+ case *reflect.BoolValue, *reflect.FloatValue, *reflect.IntValue, *reflect.UintValue, *reflect.StringValue:
+ saveData = v
+
+ case *reflect.StructValue:
+ if _, ok := v.Interface().(Name); ok {
+ v.Set(reflect.NewValue(start.Name).(*reflect.StructValue))
+ break
+ }
+
+ sv = v
+ typ := sv.Type().(*reflect.StructType)
+ styp = typ
+ // Assign name.
+ if f, ok := typ.FieldByName("XMLName"); ok {
+ // Validate element name.
+ if f.Tag != "" {
+ tag := f.Tag
+ ns := ""
+ i := strings.LastIndex(tag, " ")
+ if i >= 0 {
+ ns, tag = tag[0:i], tag[i+1:]
+ }
+ if tag != start.Name.Local {
+ return UnmarshalError("expected element type <" + tag + "> but have <" + start.Name.Local + ">")
+ }
+ if ns != "" && ns != start.Name.Space {
+ e := "expected element <" + tag + "> in name space " + ns + " but have "
+ if start.Name.Space == "" {
+ e += "no name space"
+ } else {
+ e += start.Name.Space
+ }
+ return UnmarshalError(e)
+ }
+ }
+
+ // Save
+ v := sv.FieldByIndex(f.Index)
+ if _, ok := v.Interface().(Name); !ok {
+ return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name")
+ }
+ v.(*reflect.StructValue).Set(reflect.NewValue(start.Name).(*reflect.StructValue))
+ }
+
+ // Assign attributes.
+ // Also, determine whether we need to save character data or comments.
+ for i, n := 0, typ.NumField(); i < n; i++ {
+ f := typ.Field(i)
+ switch f.Tag {
+ case "attr":
+ strv, ok := sv.FieldByIndex(f.Index).(*reflect.StringValue)
+ if !ok {
+ return UnmarshalError(sv.Type().String() + " field " + f.Name + " has attr tag but is not type string")
+ }
+ // Look for attribute.
+ val := ""
+ k := strings.ToLower(f.Name)
+ for _, a := range start.Attr {
+ if fieldName(a.Name.Local) == k {
+ val = a.Value
+ break
+ }
+ }
+ strv.Set(val)
+
+ case "comment":
+ if saveComment == nil {
+ saveComment = sv.FieldByIndex(f.Index)
+ }
+
+ case "chardata":
+ if saveData == nil {
+ saveData = sv.FieldByIndex(f.Index)
+ }
+
+ case "innerxml":
+ if saveXML == nil {
+ saveXML = sv.FieldByIndex(f.Index)
+ if p.saved == nil {
+ saveXMLIndex = 0
+ p.saved = new(bytes.Buffer)
+ } else {
+ saveXMLIndex = p.savedOffset()
+ }
+ }
+
+ default:
+ if strings.Contains(f.Tag, ">") {
+ if fieldPaths == nil {
+ fieldPaths = make(map[string]pathInfo)
+ }
+ path := strings.ToLower(f.Tag)
+ if strings.HasPrefix(f.Tag, ">") {
+ path = strings.ToLower(f.Name) + path
+ }
+ if strings.HasSuffix(f.Tag, ">") {
+ path = path[:len(path)-1]
+ }
+ err := addFieldPath(sv, fieldPaths, path, f.Index)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ }
+
+ // Find end element.
+ // Process sub-elements along the way.
+Loop:
+ for {
+ var savedOffset int
+ if saveXML != nil {
+ savedOffset = p.savedOffset()
+ }
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ // Sub-element.
+ // Look up by tag name.
+ if sv != nil {
+ k := fieldName(t.Name.Local)
+
+ if fieldPaths != nil {
+ if _, found := fieldPaths[k]; found {
+ if err := p.unmarshalPaths(sv, fieldPaths, k, &t); err != nil {
+ return err
+ }
+ continue Loop
+ }
+ }
+
+ match := func(s string) bool {
+ // check if the name matches ignoring case
+ if strings.ToLower(s) != k {
+ return false
+ }
+ // now check that it's public
+ c, _ := utf8.DecodeRuneInString(s)
+ return unicode.IsUpper(c)
+ }
+
+ f, found := styp.FieldByNameFunc(match)
+ if !found { // fall back to mop-up field named "Any"
+ f, found = styp.FieldByName("Any")
+ }
+ if found {
+ if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil {
+ return err
+ }
+ continue Loop
+ }
+ }
+ // Not saving sub-element but still have to skip over it.
+ if err := p.Skip(); err != nil {
+ return err
+ }
+
+ case EndElement:
+ if saveXML != nil {
+ saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset]
+ if saveXMLIndex == 0 {
+ p.saved = nil
+ }
+ }
+ break Loop
+
+ case CharData:
+ if saveData != nil {
+ data = append(data, t...)
+ }
+
+ case Comment:
+ if saveComment != nil {
+ comment = append(comment, t...)
+ }
+ }
+ }
+
+ var err os.Error
+ // Helper functions for integer and unsigned integer conversions
+ var itmp int64
+ getInt64 := func() bool {
+ itmp, err = strconv.Atoi64(string(data))
+ // TODO: should check sizes
+ return err == nil
+ }
+ var utmp uint64
+ getUint64 := func() bool {
+ utmp, err = strconv.Atoui64(string(data))
+ // TODO: check for overflow?
+ return err == nil
+ }
+ var ftmp float64
+ getFloat64 := func() bool {
+ ftmp, err = strconv.Atof64(string(data))
+ // TODO: check for overflow?
+ return err == nil
+ }
+
+ // Save accumulated data and comments
+ switch t := saveData.(type) {
+ case nil:
+ // Probably a comment, handled below
+ default:
+ return os.ErrorString("cannot happen: unknown type " + t.Type().String())
+ case *reflect.IntValue:
+ if !getInt64() {
+ return err
+ }
+ t.Set(itmp)
+ case *reflect.UintValue:
+ if !getUint64() {
+ return err
+ }
+ t.Set(utmp)
+ case *reflect.FloatValue:
+ if !getFloat64() {
+ return err
+ }
+ t.Set(ftmp)
+ case *reflect.BoolValue:
+ value, err := strconv.Atob(strings.TrimSpace(string(data)))
+ if err != nil {
+ return err
+ }
+ t.Set(value)
+ case *reflect.StringValue:
+ t.Set(string(data))
+ case *reflect.SliceValue:
+ t.Set(reflect.NewValue(data).(*reflect.SliceValue))
+ }
+
+ switch t := saveComment.(type) {
+ case *reflect.StringValue:
+ t.Set(string(comment))
+ case *reflect.SliceValue:
+ t.Set(reflect.NewValue(comment).(*reflect.SliceValue))
+ }
+
+ switch t := saveXML.(type) {
+ case *reflect.StringValue:
+ t.Set(string(saveXMLData))
+ case *reflect.SliceValue:
+ t.Set(reflect.NewValue(saveXMLData).(*reflect.SliceValue))
+ }
+
+ return nil
+}
+
+type pathInfo struct {
+ fieldIdx []int
+ complete bool
+}
+
+// addFieldPath takes an element path such as "a>b>c" and fills the
+// paths map with all paths leading to it ("a", "a>b", and "a>b>c").
+// It is okay for paths to share a common, shorter prefix but not ok
+// for one path to itself be a prefix of another.
+func addFieldPath(sv *reflect.StructValue, paths map[string]pathInfo, path string, fieldIdx []int) os.Error {
+ if info, found := paths[path]; found {
+ return tagError(sv, info.fieldIdx, fieldIdx)
+ }
+ paths[path] = pathInfo{fieldIdx, true}
+ for {
+ i := strings.LastIndex(path, ">")
+ if i < 0 {
+ break
+ }
+ path = path[:i]
+ if info, found := paths[path]; found {
+ if info.complete {
+ return tagError(sv, info.fieldIdx, fieldIdx)
+ }
+ } else {
+ paths[path] = pathInfo{fieldIdx, false}
+ }
+ }
+ return nil
+
+}
+
+func tagError(sv *reflect.StructValue, idx1 []int, idx2 []int) os.Error {
+ t := sv.Type().(*reflect.StructType)
+ f1 := t.FieldByIndex(idx1)
+ f2 := t.FieldByIndex(idx2)
+ return &TagPathError{t, f1.Name, f1.Tag, f2.Name, f2.Tag}
+}
+
+// unmarshalPaths walks down an XML structure looking for
+// wanted paths, and calls unmarshal on them.
+func (p *Parser) unmarshalPaths(sv *reflect.StructValue, paths map[string]pathInfo, path string, start *StartElement) os.Error {
+ if info, _ := paths[path]; info.complete {
+ return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start)
+ }
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ k := path + ">" + fieldName(t.Name.Local)
+ if _, found := paths[k]; found {
+ if err := p.unmarshalPaths(sv, paths, k, &t); err != nil {
+ return err
+ }
+ continue
+ }
+ if err := p.Skip(); err != nil {
+ return err
+ }
+ case EndElement:
+ return nil
+ }
+ }
+ panic("unreachable")
+}
+
+// Have already read a start element.
+// Read tokens until we find the end element.
+// Token is taking care of making sure the
+// end element matches the start element we saw.
+func (p *Parser) Skip() os.Error {
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ if err := p.Skip(); err != nil {
+ return err
+ }
+ case EndElement:
+ return nil
+ }
+ }
+ panic("unreachable")
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xml
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
+// an XML element is an order-dependent collection of anonymous
+// values, while a data structure is an order-independent collection
+// of named values.
+// See package json for a textual representation more suitable
+// to data structures.
+
+// Unmarshal parses an XML element from r and uses the
+// reflect library to fill in an arbitrary struct, slice, or string
+// pointed at by val. Well-formed data that does not fit
+// into val is discarded.
+//
+// For example, given these definitions:
+//
+// type Email struct {
+// Where string "attr"
+// Addr string
+// }
+//
+// type Result struct {
+// XMLName xml.Name "result"
+// Name string
+// Phone string
+// Email []Email
+// Groups []string "group>value"
+// }
+//
+// result := Result{Name: "name", Phone: "phone", Email: nil}
+//
+// unmarshalling the XML input
+//
+// <result>
+// <email where="home">
+// <addr>gre@example.com</addr>
+// </email>
+// <email where='work'>
+// <addr>gre@work.com</addr>
+// </email>
+// <name>Grace R. Emlin</name>
+// <group>
+// <value>Friends</value>
+// <value>Squash</value>
+// </group>
+// <address>123 Main Street</address>
+// </result>
+//
+// via Unmarshal(r, &result) is equivalent to assigning
+//
+// r = Result{xml.Name{"", "result"},
+// "Grace R. Emlin", // name
+// "phone", // no phone given
+// []Email{
+// Email{"home", "gre@example.com"},
+// Email{"work", "gre@work.com"},
+// },
+// []string{"Friends", "Squash"},
+// }
+//
+// Note that the field r.Phone has not been modified and
+// that the XML <address> element was discarded. Also, the field
+// Groups was assigned considering the element path provided in the
+// field tag.
+//
+// Because Unmarshal uses the reflect package, it can only
+// assign to upper case fields. Unmarshal uses a case-insensitive
+// comparison to match XML element names to struct field names.
+//
+// Unmarshal maps an XML element to a struct using the following rules:
+//
+// * If the struct has a field of type []byte or string with tag "innerxml",
+// Unmarshal accumulates the raw XML nested inside the element
+// in that field. The rest of the rules still apply.
+//
+// * If the struct has a field named XMLName of type xml.Name,
+// Unmarshal records the element name in that field.
+//
+// * If the XMLName field has an associated tag string of the form
+// "tag" or "namespace-URL tag", the XML element must have
+// the given tag (and, optionally, name space) or else Unmarshal
+// returns an error.
+//
+// * If the XML element has an attribute whose name matches a
+// struct field of type string with tag "attr", Unmarshal records
+// the attribute value in that field.
+//
+// * If the XML element contains character data, that data is
+// accumulated in the first struct field that has tag "chardata".
+// The struct field may have type []byte or string.
+// If there is no such field, the character data is discarded.
+//
+// * If the XML element contains a sub-element whose name matches
+// the prefix of a struct field tag formatted as "a>b>c", unmarshal
+// will descend into the XML structure looking for elements with the
+// given names, and will map the innermost elements to that struct field.
+// A struct field tag starting with ">" is equivalent to one starting
+// with the field name followed by ">".
+//
+// * If the XML element contains a sub-element whose name
+// matches a struct field whose tag is neither "attr" nor "chardata",
+// Unmarshal maps the sub-element to that struct field.
+// Otherwise, if the struct has a field named Any, unmarshal
+// maps the sub-element to that struct field.
+//
+// Unmarshal maps an XML element to a string or []byte by saving the
+// concatenation of that element's character data in the string or []byte.
+//
+// Unmarshal maps an XML element to a slice by extending the length
+// of the slice and mapping the element to the newly created value.
+//
+// Unmarshal maps an XML element to a bool by setting it to the boolean
+// value represented by the string.
+//
+// Unmarshal maps an XML element to an integer or floating-point
+// field by setting the field to the result of interpreting the string
+// value in decimal. There is no check for overflow.
+//
+// Unmarshal maps an XML element to an xml.Name by recording the
+// element name.
+//
+// Unmarshal maps an XML element to a pointer by setting the pointer
+// to a freshly allocated value and then mapping the element to that value.
+//
+func Unmarshal(r io.Reader, val interface{}) os.Error {
+ v := reflect.NewValue(val)
+ if v.Kind() != reflect.Ptr {
+ return os.NewError("non-pointer passed to Unmarshal")
+ }
+ p := NewParser(r)
+ elem := v.Elem()
+ err := p.unmarshal(elem, nil)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// An UnmarshalError represents an error in the unmarshalling process.
+type UnmarshalError string
+
+func (e UnmarshalError) String() string { return string(e) }
+
+// A TagPathError represents an error in the unmarshalling process
+// caused by the use of field tags with conflicting paths.
+type TagPathError struct {
+ Struct reflect.Type
+ Field1, Tag1 string
+ Field2, Tag2 string
+}
+
+func (e *TagPathError) String() string {
+ return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
+}
+
+// The Parser's Unmarshal method is like xml.Unmarshal
+// except that it can be passed a pointer to the initial start element,
+// useful when a client reads some raw XML tokens itself
+// but also defers to Unmarshal for some elements.
+// Passing a nil start element indicates that Unmarshal should
+// read the token stream to find the start element.
+func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error {
+ v := reflect.NewValue(val)
+ if v.Kind() != reflect.Ptr {
+ return os.NewError("non-pointer passed to Unmarshal")
+ }
+ return p.unmarshal(v.Elem(), start)
+}
+
+// fieldName strips invalid characters from an XML name
+// to create a valid Go struct name. It also converts the
+// name to lower case letters.
+func fieldName(original string) string {
+
+ var i int
+ //remove leading underscores
+ for i = 0; i < len(original) && original[i] == '_'; i++ {
+ }
+
+ return strings.Map(
+ func(x int) int {
+ if x == '_' || unicode.IsDigit(x) || unicode.IsLetter(x) {
+ return unicode.ToLower(x)
+ }
+ return -1
+ },
+ original[i:])
+}
+
+// Unmarshal a single XML element into val.
+func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error {
+ // Find start element if we need it.
+ if start == nil {
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ if t, ok := tok.(StartElement); ok {
+ start = &t
+ break
+ }
+ }
+ }
+
+ if pv := val; pv.Kind() == reflect.Ptr {
+ if pv.Pointer() == 0 {
+ zv := reflect.Zero(pv.Type().Elem())
+ pv.Set(zv.Addr())
+ val = zv
+ } else {
+ val = pv.Elem()
+ }
+ }
+
+ var (
+ data []byte
+ saveData reflect.Value
+ comment []byte
+ saveComment reflect.Value
+ saveXML reflect.Value
+ saveXMLIndex int
+ saveXMLData []byte
+ sv reflect.Value
+ styp reflect.Type
+ fieldPaths map[string]pathInfo
+ )
+
+ switch v := val; v.Kind() {
+ default:
+ return os.ErrorString("unknown type " + v.Type().String())
+
+ case reflect.Slice:
+ typ := v.Type()
+ if typ.Elem().Kind() == reflect.Uint8 {
+ // []byte
+ saveData = v
+ break
+ }
+
+ // Slice of element values.
+ // Grow slice.
+ n := v.Len()
+ if n >= v.Cap() {
+ ncap := 2 * n
+ if ncap < 4 {
+ ncap = 4
+ }
+ new := reflect.MakeSlice(typ, n, ncap)
+ reflect.Copy(new, v)
+ v.Set(new)
+ }
+ v.SetLen(n + 1)
+
+ // Recur to read element into slice.
+ if err := p.unmarshal(v.Index(n), start); err != nil {
+ v.SetLen(n)
+ return err
+ }
+ return nil
+
+ case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String:
+ saveData = v
+
+ case reflect.Struct:
+ if _, ok := v.Interface().(Name); ok {
+ v.Set(reflect.NewValue(start.Name))
+ break
+ }
+
+ sv = v
+ typ := sv.Type()
+ styp = typ
+ // Assign name.
+ if f, ok := typ.FieldByName("XMLName"); ok {
+ // Validate element name.
+ if f.Tag != "" {
+ tag := f.Tag
+ ns := ""
+ i := strings.LastIndex(tag, " ")
+ if i >= 0 {
+ ns, tag = tag[0:i], tag[i+1:]
+ }
+ if tag != start.Name.Local {
+ return UnmarshalError("expected element type <" + tag + "> but have <" + start.Name.Local + ">")
+ }
+ if ns != "" && ns != start.Name.Space {
+ e := "expected element <" + tag + "> in name space " + ns + " but have "
+ if start.Name.Space == "" {
+ e += "no name space"
+ } else {
+ e += start.Name.Space
+ }
+ return UnmarshalError(e)
+ }
+ }
+
+ // Save
+ v := sv.FieldByIndex(f.Index)
+ if _, ok := v.Interface().(Name); !ok {
+ return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name")
+ }
+ v.Set(reflect.NewValue(start.Name))
+ }
+
+ // Assign attributes.
+ // Also, determine whether we need to save character data or comments.
+ for i, n := 0, typ.NumField(); i < n; i++ {
+ f := typ.Field(i)
+ switch f.Tag {
+ case "attr":
+ strv := sv.FieldByIndex(f.Index)
+ if strv.Kind() != reflect.String {
+ return UnmarshalError(sv.Type().String() + " field " + f.Name + " has attr tag but is not type string")
+ }
+ // Look for attribute.
+ val := ""
+ k := strings.ToLower(f.Name)
+ for _, a := range start.Attr {
+ if fieldName(a.Name.Local) == k {
+ val = a.Value
+ break
+ }
+ }
+ strv.SetString(val)
+
+ case "comment":
+ if !saveComment.IsValid() {
+ saveComment = sv.FieldByIndex(f.Index)
+ }
+
+ case "chardata":
+ if !saveData.IsValid() {
+ saveData = sv.FieldByIndex(f.Index)
+ }
+
+ case "innerxml":
+ if !saveXML.IsValid() {
+ saveXML = sv.FieldByIndex(f.Index)
+ if p.saved == nil {
+ saveXMLIndex = 0
+ p.saved = new(bytes.Buffer)
+ } else {
+ saveXMLIndex = p.savedOffset()
+ }
+ }
+
+ default:
+ if strings.Contains(f.Tag, ">") {
+ if fieldPaths == nil {
+ fieldPaths = make(map[string]pathInfo)
+ }
+ path := strings.ToLower(f.Tag)
+ if strings.HasPrefix(f.Tag, ">") {
+ path = strings.ToLower(f.Name) + path
+ }
+ if strings.HasSuffix(f.Tag, ">") {
+ path = path[:len(path)-1]
+ }
+ err := addFieldPath(sv, fieldPaths, path, f.Index)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ }
+
+ // Find end element.
+ // Process sub-elements along the way.
+Loop:
+ for {
+ var savedOffset int
+ if saveXML.IsValid() {
+ savedOffset = p.savedOffset()
+ }
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ // Sub-element.
+ // Look up by tag name.
+ if sv.IsValid() {
+ k := fieldName(t.Name.Local)
+
+ if fieldPaths != nil {
+ if _, found := fieldPaths[k]; found {
+ if err := p.unmarshalPaths(sv, fieldPaths, k, &t); err != nil {
+ return err
+ }
+ continue Loop
+ }
+ }
+
+ match := func(s string) bool {
+ // check if the name matches ignoring case
+ if strings.ToLower(s) != k {
+ return false
+ }
+ // now check that it's public
+ c, _ := utf8.DecodeRuneInString(s)
+ return unicode.IsUpper(c)
+ }
+
+ f, found := styp.FieldByNameFunc(match)
+ if !found { // fall back to mop-up field named "Any"
+ f, found = styp.FieldByName("Any")
+ }
+ if found {
+ if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil {
+ return err
+ }
+ continue Loop
+ }
+ }
+ // Not saving sub-element but still have to skip over it.
+ if err := p.Skip(); err != nil {
+ return err
+ }
+
+ case EndElement:
+ if saveXML.IsValid() {
+ saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset]
+ if saveXMLIndex == 0 {
+ p.saved = nil
+ }
+ }
+ break Loop
+
+ case CharData:
+ if saveData.IsValid() {
+ data = append(data, t...)
+ }
+
+ case Comment:
+ if saveComment.IsValid() {
+ comment = append(comment, t...)
+ }
+ }
+ }
+
+ var err os.Error
+ // Helper functions for integer and unsigned integer conversions
+ var itmp int64
+ getInt64 := func() bool {
+ itmp, err = strconv.Atoi64(string(data))
+ // TODO: should check sizes
+ return err == nil
+ }
+ var utmp uint64
+ getUint64 := func() bool {
+ utmp, err = strconv.Atoui64(string(data))
+ // TODO: check for overflow?
+ return err == nil
+ }
+ var ftmp float64
+ getFloat64 := func() bool {
+ ftmp, err = strconv.Atof64(string(data))
+ // TODO: check for overflow?
+ return err == nil
+ }
+
+ // Save accumulated data and comments
+ switch t := saveData; t.Kind() {
+ case reflect.Invalid:
+ // Probably a comment, handled below
+ default:
+ return os.ErrorString("cannot happen: unknown type " + t.Type().String())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if !getInt64() {
+ return err
+ }
+ t.SetInt(itmp)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ if !getUint64() {
+ return err
+ }
+ t.SetUint(utmp)
+ case reflect.Float32, reflect.Float64:
+ if !getFloat64() {
+ return err
+ }
+ t.SetFloat(ftmp)
+ case reflect.Bool:
+ value, err := strconv.Atob(strings.TrimSpace(string(data)))
+ if err != nil {
+ return err
+ }
+ t.SetBool(value)
+ case reflect.String:
+ t.SetString(string(data))
+ case reflect.Slice:
+ t.Set(reflect.NewValue(data))
+ }
+
+ switch t := saveComment; t.Kind() {
+ case reflect.String:
+ t.SetString(string(comment))
+ case reflect.Slice:
+ t.Set(reflect.NewValue(comment))
+ }
+
+ switch t := saveXML; t.Kind() {
+ case reflect.String:
+ t.SetString(string(saveXMLData))
+ case reflect.Slice:
+ t.Set(reflect.NewValue(saveXMLData))
+ }
+
+ return nil
+}
+
+type pathInfo struct {
+ fieldIdx []int
+ complete bool
+}
+
+// addFieldPath takes an element path such as "a>b>c" and fills the
+// paths map with all paths leading to it ("a", "a>b", and "a>b>c").
+// It is okay for paths to share a common, shorter prefix but not ok
+// for one path to itself be a prefix of another.
+func addFieldPath(sv reflect.Value, paths map[string]pathInfo, path string, fieldIdx []int) os.Error {
+ if info, found := paths[path]; found {
+ return tagError(sv, info.fieldIdx, fieldIdx)
+ }
+ paths[path] = pathInfo{fieldIdx, true}
+ for {
+ i := strings.LastIndex(path, ">")
+ if i < 0 {
+ break
+ }
+ path = path[:i]
+ if info, found := paths[path]; found {
+ if info.complete {
+ return tagError(sv, info.fieldIdx, fieldIdx)
+ }
+ } else {
+ paths[path] = pathInfo{fieldIdx, false}
+ }
+ }
+ return nil
+
+}
+
+func tagError(sv reflect.Value, idx1 []int, idx2 []int) os.Error {
+ t := sv.Type()
+ f1 := t.FieldByIndex(idx1)
+ f2 := t.FieldByIndex(idx2)
+ return &TagPathError{t, f1.Name, f1.Tag, f2.Name, f2.Tag}
+}
+
+// unmarshalPaths walks down an XML structure looking for
+// wanted paths, and calls unmarshal on them.
+func (p *Parser) unmarshalPaths(sv reflect.Value, paths map[string]pathInfo, path string, start *StartElement) os.Error {
+ if info, _ := paths[path]; info.complete {
+ return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start)
+ }
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ k := path + ">" + fieldName(t.Name.Local)
+ if _, found := paths[k]; found {
+ if err := p.unmarshalPaths(sv, paths, k, &t); err != nil {
+ return err
+ }
+ continue
+ }
+ if err := p.Skip(); err != nil {
+ return err
+ }
+ case EndElement:
+ return nil
+ }
+ }
+ panic("unreachable")
+}
+
+// Have already read a start element.
+// Read tokens until we find the end element.
+// Token is taking care of making sure the
+// end element matches the start element we saw.
+func (p *Parser) Skip() os.Error {
+ for {
+ tok, err := p.Token()
+ if err != nil {
+ return err
+ }
+ switch t := tok.(type) {
+ case StartElement:
+ if err := p.Skip(); err != nil {
+ return err
+ }
+ case EndElement:
+ return nil
+ }
+ }
+ panic("unreachable")
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+ "bytes"
+ "io"
+ "math"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// runeUnreader is the interface to something that can unread runes.
+// If the object provided to Scan does not satisfy this interface,
+// a local buffer will be used to back up the input, but its contents
+// will be lost when Scan returns.
+type runeUnreader interface {
+ UnreadRune() os.Error
+}
+
+// ScanState represents the scanner state passed to custom scanners.
+// Scanners may do rune-at-a-time scanning or ask the ScanState
+// to discover the next space-delimited token.
+type ScanState interface {
+ // ReadRune reads the next rune (Unicode code point) from the input.
+ // If invoked during Scanln, Fscanln, or Sscanln, ReadRune() will
+ // return EOF after returning the first '\n' or when reading beyond
+ // the specified width.
+ ReadRune() (rune int, size int, err os.Error)
+ // UnreadRune causes the next call to ReadRune to return the same rune.
+ UnreadRune() os.Error
+ // Token skips space in the input if skipSpace is true, then returns the
+ // run of Unicode code points c satisfying f(c). If f is nil,
+ // !unicode.IsSpace(c) is used; that is, the token will hold non-space
+ // characters. Newlines are treated as space unless the scan operation
+ // is Scanln, Fscanln or Sscanln, in which case a newline is treated as
+ // EOF. The returned slice points to shared data that may be overwritten
+ // by the next call to Token, a call to a Scan function using the ScanState
+ // as input, or when the calling Scan method returns.
+ Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error)
+ // Width returns the value of the width option and whether it has been set.
+ // The unit is Unicode code points.
+ Width() (wid int, ok bool)
+ // Because ReadRune is implemented by the interface, Read should never be
+ // called by the scanning routines and a valid implementation of
+ // ScanState may choose always to return an error from Read.
+ Read(buf []byte) (n int, err os.Error)
+}
+
+// Scanner is implemented by any value that has a Scan method, which scans
+// the input for the representation of a value and stores the result in the
+// receiver, which must be a pointer to be useful. The Scan method is called
+// for any argument to Scan, Scanf, or Scanln that implements it.
+type Scanner interface {
+ Scan(state ScanState, verb int) os.Error
+}
+
+// Scan scans text read from standard input, storing successive
+// space-separated values into successive arguments. Newlines count
+// as space. It returns the number of items successfully scanned.
+// If that is less than the number of arguments, err will report why.
+func Scan(a ...interface{}) (n int, err os.Error) {
+ return Fscan(os.Stdin, a...)
+}
+
+// Scanln is similar to Scan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Scanln(a ...interface{}) (n int, err os.Error) {
+ return Fscanln(os.Stdin, a...)
+}
+
+// Scanf scans text read from standard input, storing successive
+// space-separated values into successive arguments as determined by
+// the format. It returns the number of items successfully scanned.
+func Scanf(format string, a ...interface{}) (n int, err os.Error) {
+ return Fscanf(os.Stdin, format, a...)
+}
+
+// Sscan scans the argument string, storing successive space-separated
+// values into successive arguments. Newlines count as space. It
+// returns the number of items successfully scanned. If that is less
+// than the number of arguments, err will report why.
+func Sscan(str string, a ...interface{}) (n int, err os.Error) {
+ return Fscan(strings.NewReader(str), a...)
+}
+
+// Sscanln is similar to Sscan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Sscanln(str string, a ...interface{}) (n int, err os.Error) {
+ return Fscanln(strings.NewReader(str), a...)
+}
+
+// Sscanf scans the argument string, storing successive space-separated
+// values into successive arguments as determined by the format. It
+// returns the number of items successfully parsed.
+func Sscanf(str string, format string, a ...interface{}) (n int, err os.Error) {
+ return Fscanf(strings.NewReader(str), format, a...)
+}
+
+// Fscan scans text read from r, storing successive space-separated
+// values into successive arguments. Newlines count as space. It
+// returns the number of items successfully scanned. If that is less
+// than the number of arguments, err will report why.
+func Fscan(r io.Reader, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, true, false)
+ n, err = s.doScan(a)
+ s.free(old)
+ return
+}
+
+// Fscanln is similar to Fscan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Fscanln(r io.Reader, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, false, true)
+ n, err = s.doScan(a)
+ s.free(old)
+ return
+}
+
+// Fscanf scans text read from r, storing successive space-separated
+// values into successive arguments as determined by the format. It
+// returns the number of items successfully parsed.
+func Fscanf(r io.Reader, format string, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, false, false)
+ n, err = s.doScanf(format, a)
+ s.free(old)
+ return
+}
+
+// scanError represents an error generated by the scanning software.
+// It's used as a unique signature to identify such errors when recovering.
+type scanError struct {
+ err os.Error
+}
+
+const eof = -1
+
+// ss is the internal implementation of ScanState.
+type ss struct {
+ rr io.RuneReader // where to read input
+ buf bytes.Buffer // token accumulator
+ peekRune int // one-rune lookahead
+ prevRune int // last rune returned by ReadRune
+ count int // runes consumed so far.
+ atEOF bool // already read EOF
+ ssave
+}
+
+// ssave holds the parts of ss that need to be
+// saved and restored on recursive scans.
+type ssave struct {
+ validSave bool // is or was a part of an actual ss.
+ nlIsEnd bool // whether newline terminates scan
+ nlIsSpace bool // whether newline counts as white space
+ fieldLimit int // max value of ss.count for this field; fieldLimit <= limit
+ limit int // max value of ss.count.
+ maxWid int // width of this field.
+}
+
+// The Read method is only in ScanState so that ScanState
+// satisfies io.Reader. It will never be called when used as
+// intended, so there is no need to make it actually work.
+func (s *ss) Read(buf []byte) (n int, err os.Error) {
+ return 0, os.ErrorString("ScanState's Read should not be called. Use ReadRune")
+}
+
+func (s *ss) ReadRune() (rune int, size int, err os.Error) {
+ if s.peekRune >= 0 {
+ s.count++
+ rune = s.peekRune
+ size = utf8.RuneLen(rune)
+ s.prevRune = rune
+ s.peekRune = -1
+ return
+ }
+ if s.atEOF || s.nlIsEnd && s.prevRune == '\n' || s.count >= s.fieldLimit {
+ err = os.EOF
+ return
+ }
+
+ rune, size, err = s.rr.ReadRune()
+ if err == nil {
+ s.count++
+ s.prevRune = rune
+ } else if err == os.EOF {
+ s.atEOF = true
+ }
+ return
+}
+
+func (s *ss) Width() (wid int, ok bool) {
+ if s.maxWid == hugeWid {
+ return 0, false
+ }
+ return s.maxWid, true
+}
+
+// The public method returns an error; this private one panics.
+// If getRune reaches EOF, the return value is EOF (-1).
+func (s *ss) getRune() (rune int) {
+ rune, _, err := s.ReadRune()
+ if err != nil {
+ if err == os.EOF {
+ return eof
+ }
+ s.error(err)
+ }
+ return
+}
+
+// mustReadRune turns os.EOF into a panic(io.ErrUnexpectedEOF).
+// It is called in cases such as string scanning where an EOF is a
+// syntax error.
+func (s *ss) mustReadRune() (rune int) {
+ rune = s.getRune()
+ if rune == eof {
+ s.error(io.ErrUnexpectedEOF)
+ }
+ return
+}
+
+func (s *ss) UnreadRune() os.Error {
+ if u, ok := s.rr.(runeUnreader); ok {
+ u.UnreadRune()
+ } else {
+ s.peekRune = s.prevRune
+ }
+ s.count--
+ return nil
+}
+
+func (s *ss) error(err os.Error) {
+ panic(scanError{err})
+}
+
+func (s *ss) errorString(err string) {
+ panic(scanError{os.ErrorString(err)})
+}
+
+func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) {
+ defer func() {
+ if e := recover(); e != nil {
+ if se, ok := e.(scanError); ok {
+ err = se.err
+ } else {
+ panic(e)
+ }
+ }
+ }()
+ if f == nil {
+ f = notSpace
+ }
+ s.buf.Reset()
+ tok = s.token(skipSpace, f)
+ return
+}
+
+// notSpace is the default scanning function used in Token.
+func notSpace(r int) bool {
+ return !unicode.IsSpace(r)
+}
+
+// readRune is a structure to enable reading UTF-8 encoded code points
+// from an io.Reader. It is used if the Reader given to the scanner does
+// not already implement io.RuneReader.
+type readRune struct {
+ reader io.Reader
+ buf [utf8.UTFMax]byte // used only inside ReadRune
+ pending int // number of bytes in pendBuf; only >0 for bad UTF-8
+ pendBuf [utf8.UTFMax]byte // bytes left over
+}
+
+// readByte returns the next byte from the input, which may be
+// left over from a previous read if the UTF-8 was ill-formed.
+func (r *readRune) readByte() (b byte, err os.Error) {
+ if r.pending > 0 {
+ b = r.pendBuf[0]
+ copy(r.pendBuf[0:], r.pendBuf[1:])
+ r.pending--
+ return
+ }
+ _, err = r.reader.Read(r.pendBuf[0:1])
+ return r.pendBuf[0], err
+}
+
+// unread saves the bytes for the next read.
+func (r *readRune) unread(buf []byte) {
+ copy(r.pendBuf[r.pending:], buf)
+ r.pending += len(buf)
+}
+
+// ReadRune returns the next UTF-8 encoded code point from the
+// io.Reader inside r.
+func (r *readRune) ReadRune() (rune int, size int, err os.Error) {
+ r.buf[0], err = r.readByte()
+ if err != nil {
+ return 0, 0, err
+ }
+ if r.buf[0] < utf8.RuneSelf { // fast check for common ASCII case
+ rune = int(r.buf[0])
+ return
+ }
+ var n int
+ for n = 1; !utf8.FullRune(r.buf[0:n]); n++ {
+ r.buf[n], err = r.readByte()
+ if err != nil {
+ if err == os.EOF {
+ err = nil
+ break
+ }
+ return
+ }
+ }
+ rune, size = utf8.DecodeRune(r.buf[0:n])
+ if size < n { // an error
+ r.unread(r.buf[size:n])
+ }
+ return
+}
+
+
+var ssFree = newCache(func() interface{} { return new(ss) })
+
+// Allocate a new ss struct or grab a cached one.
+func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) {
+ // If the reader is a *ss, then we've got a recursive
+ // call to Scan, so re-use the scan state.
+ s, ok := r.(*ss)
+ if ok {
+ old = s.ssave
+ s.limit = s.fieldLimit
+ s.nlIsEnd = nlIsEnd || s.nlIsEnd
+ s.nlIsSpace = nlIsSpace
+ return
+ }
+
+ s = ssFree.get().(*ss)
+ if rr, ok := r.(io.RuneReader); ok {
+ s.rr = rr
+ } else {
+ s.rr = &readRune{reader: r}
+ }
+ s.nlIsSpace = nlIsSpace
+ s.nlIsEnd = nlIsEnd
+ s.prevRune = -1
+ s.peekRune = -1
+ s.atEOF = false
+ s.limit = hugeWid
+ s.fieldLimit = hugeWid
+ s.maxWid = hugeWid
+ s.validSave = true
+ return
+}
+
+// Save used ss structs in ssFree; avoid an allocation per invocation.
+func (s *ss) free(old ssave) {
+ // If it was used recursively, just restore the old state.
+ if old.validSave {
+ s.ssave = old
+ return
+ }
+ // Don't hold on to ss structs with large buffers.
+ if cap(s.buf.Bytes()) > 1024 {
+ return
+ }
+ s.buf.Reset()
+ s.rr = nil
+ ssFree.put(s)
+}
+
+// skipSpace skips spaces and maybe newlines.
+func (s *ss) skipSpace(stopAtNewline bool) {
+ for {
+ rune := s.getRune()
+ if rune == eof {
+ return
+ }
+ if rune == '\n' {
+ if stopAtNewline {
+ break
+ }
+ if s.nlIsSpace {
+ continue
+ }
+ s.errorString("unexpected newline")
+ return
+ }
+ if !unicode.IsSpace(rune) {
+ s.UnreadRune()
+ break
+ }
+ }
+}
+
+
+// token returns the next space-delimited string from the input. It
+// skips white space. For Scanln, it stops at newlines. For Scan,
+// newlines are treated as spaces.
+func (s *ss) token(skipSpace bool, f func(int) bool) []byte {
+ if skipSpace {
+ s.skipSpace(false)
+ }
+ // read until white space or newline
+ for {
+ rune := s.getRune()
+ if rune == eof {
+ break
+ }
+ if !f(rune) {
+ s.UnreadRune()
+ break
+ }
+ s.buf.WriteRune(rune)
+ }
+ return s.buf.Bytes()
+}
+
+// typeError indicates that the type of the operand did not match the format
+func (s *ss) typeError(field interface{}, expected string) {
+ s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String())
+}
+
+var complexError = os.ErrorString("syntax error scanning complex number")
+var boolError = os.ErrorString("syntax error scanning boolean")
+
+// consume reads the next rune in the input and reports whether it is in the ok string.
+// If accept is true, it puts the character into the input token.
+func (s *ss) consume(ok string, accept bool) bool {
+ rune := s.getRune()
+ if rune == eof {
+ return false
+ }
+ if strings.IndexRune(ok, rune) >= 0 {
+ if accept {
+ s.buf.WriteRune(rune)
+ }
+ return true
+ }
+ if rune != eof && accept {
+ s.UnreadRune()
+ }
+ return false
+}
+
+// peek reports whether the next character is in the ok string, without consuming it.
+func (s *ss) peek(ok string) bool {
+ rune := s.getRune()
+ if rune != eof {
+ s.UnreadRune()
+ }
+ return strings.IndexRune(ok, rune) >= 0
+}
+
+// accept checks the next rune in the input. If it's a byte (sic) in the string, it puts it in the
+// buffer and returns true. Otherwise it return false.
+func (s *ss) accept(ok string) bool {
+ return s.consume(ok, true)
+}
+
+// okVerb verifies that the verb is present in the list, setting s.err appropriately if not.
+func (s *ss) okVerb(verb int, okVerbs, typ string) bool {
+ for _, v := range okVerbs {
+ if v == verb {
+ return true
+ }
+ }
+ s.errorString("bad verb %" + string(verb) + " for " + typ)
+ return false
+}
+
+// scanBool returns the value of the boolean represented by the next token.
+func (s *ss) scanBool(verb int) bool {
+ if !s.okVerb(verb, "tv", "boolean") {
+ return false
+ }
+ // Syntax-checking a boolean is annoying. We're not fastidious about case.
+ switch s.mustReadRune() {
+ case '0':
+ return false
+ case '1':
+ return true
+ case 't', 'T':
+ if s.accept("rR") && (!s.accept("uU") || !s.accept("eE")) {
+ s.error(boolError)
+ }
+ return true
+ case 'f', 'F':
+ if s.accept("aL") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) {
+ s.error(boolError)
+ }
+ return false
+ }
+ return false
+}
+
+// Numerical elements
+const (
+ binaryDigits = "01"
+ octalDigits = "01234567"
+ decimalDigits = "0123456789"
+ hexadecimalDigits = "0123456789aAbBcCdDeEfF"
+ sign = "+-"
+ period = "."
+ exponent = "eEp"
+)
+
+// getBase returns the numeric base represented by the verb and its digit string.
+func (s *ss) getBase(verb int) (base int, digits string) {
+ s.okVerb(verb, "bdoUxXv", "integer") // sets s.err
+ base = 10
+ digits = decimalDigits
+ switch verb {
+ case 'b':
+ base = 2
+ digits = binaryDigits
+ case 'o':
+ base = 8
+ digits = octalDigits
+ case 'x', 'X', 'U':
+ base = 16
+ digits = hexadecimalDigits
+ }
+ return
+}
+
+// scanNumber returns the numerical string with specified digits starting here.
+func (s *ss) scanNumber(digits string, haveDigits bool) string {
+ if !haveDigits && !s.accept(digits) {
+ s.errorString("expected integer")
+ }
+ for s.accept(digits) {
+ }
+ return s.buf.String()
+}
+
+// scanRune returns the next rune value in the input.
+func (s *ss) scanRune(bitSize int) int64 {
+ rune := int64(s.mustReadRune())
+ n := uint(bitSize)
+ x := (rune << (64 - n)) >> (64 - n)
+ if x != rune {
+ s.errorString("overflow on character value " + string(rune))
+ }
+ return rune
+}
+
+// scanBasePrefix reports whether the integer begins with a 0 or 0x,
+// and returns the base, digit string, and whether a zero was found.
+// It is called only if the verb is %v.
+func (s *ss) scanBasePrefix() (base int, digits string, found bool) {
+ if !s.peek("0") {
+ return 10, decimalDigits, false
+ }
+ s.accept("0")
+ found = true // We've put a digit into the token buffer.
+ // Special cases for '0' && '0x'
+ base, digits = 8, octalDigits
+ if s.peek("xX") {
+ s.consume("xX", false)
+ base, digits = 16, hexadecimalDigits
+ }
+ return
+}
+
+// scanInt returns the value of the integer represented by the next
+// token, checking for overflow. Any error is stored in s.err.
+func (s *ss) scanInt(verb int, bitSize int) int64 {
+ if verb == 'c' {
+ return s.scanRune(bitSize)
+ }
+ s.skipSpace(false)
+ base, digits := s.getBase(verb)
+ haveDigits := false
+ if verb == 'U' {
+ if !s.consume("U", false) || !s.consume("+", false) {
+ s.errorString("bad unicode format ")
+ }
+ } else {
+ s.accept(sign) // If there's a sign, it will be left in the token buffer.
+ if verb == 'v' {
+ base, digits, haveDigits = s.scanBasePrefix()
+ }
+ }
+ tok := s.scanNumber(digits, haveDigits)
+ i, err := strconv.Btoi64(tok, base)
+ if err != nil {
+ s.error(err)
+ }
+ n := uint(bitSize)
+ x := (i << (64 - n)) >> (64 - n)
+ if x != i {
+ s.errorString("integer overflow on token " + tok)
+ }
+ return i
+}
+
+// scanUint returns the value of the unsigned integer represented
+// by the next token, checking for overflow. Any error is stored in s.err.
+func (s *ss) scanUint(verb int, bitSize int) uint64 {
+ if verb == 'c' {
+ return uint64(s.scanRune(bitSize))
+ }
+ s.skipSpace(false)
+ base, digits := s.getBase(verb)
+ haveDigits := false
+ if verb == 'U' {
+ if !s.consume("U", false) || !s.consume("+", false) {
+ s.errorString("bad unicode format ")
+ }
+ } else if verb == 'v' {
+ base, digits, haveDigits = s.scanBasePrefix()
+ }
+ tok := s.scanNumber(digits, haveDigits)
+ i, err := strconv.Btoui64(tok, base)
+ if err != nil {
+ s.error(err)
+ }
+ n := uint(bitSize)
+ x := (i << (64 - n)) >> (64 - n)
+ if x != i {
+ s.errorString("unsigned integer overflow on token " + tok)
+ }
+ return i
+}
+
+// floatToken returns the floating-point number starting here, no longer than swid
+// if the width is specified. It's not rigorous about syntax because it doesn't check that
+// we have at least some digits, but Atof will do that.
+func (s *ss) floatToken() string {
+ s.buf.Reset()
+ // NaN?
+ if s.accept("nN") && s.accept("aA") && s.accept("nN") {
+ return s.buf.String()
+ }
+ // leading sign?
+ s.accept(sign)
+ // Inf?
+ if s.accept("iI") && s.accept("nN") && s.accept("fF") {
+ return s.buf.String()
+ }
+ // digits?
+ for s.accept(decimalDigits) {
+ }
+ // decimal point?
+ if s.accept(period) {
+ // fraction?
+ for s.accept(decimalDigits) {
+ }
+ }
+ // exponent?
+ if s.accept(exponent) {
+ // leading sign?
+ s.accept(sign)
+ // digits?
+ for s.accept(decimalDigits) {
+ }
+ }
+ return s.buf.String()
+}
+
+// complexTokens returns the real and imaginary parts of the complex number starting here.
+// The number might be parenthesized and has the format (N+Ni) where N is a floating-point
+// number and there are no spaces within.
+func (s *ss) complexTokens() (real, imag string) {
+ // TODO: accept N and Ni independently?
+ parens := s.accept("(")
+ real = s.floatToken()
+ s.buf.Reset()
+ // Must now have a sign.
+ if !s.accept("+-") {
+ s.error(complexError)
+ }
+ // Sign is now in buffer
+ imagSign := s.buf.String()
+ imag = s.floatToken()
+ if !s.accept("i") {
+ s.error(complexError)
+ }
+ if parens && !s.accept(")") {
+ s.error(complexError)
+ }
+ return real, imagSign + imag
+}
+
+// convertFloat converts the string to a float64value.
+func (s *ss) convertFloat(str string, n int) float64 {
+ if p := strings.Index(str, "p"); p >= 0 {
+ // Atof doesn't handle power-of-2 exponents,
+ // but they're easy to evaluate.
+ f, err := strconv.AtofN(str[:p], n)
+ if err != nil {
+ // Put full string into error.
+ if e, ok := err.(*strconv.NumError); ok {
+ e.Num = str
+ }
+ s.error(err)
+ }
+ n, err := strconv.Atoi(str[p+1:])
+ if err != nil {
+ // Put full string into error.
+ if e, ok := err.(*strconv.NumError); ok {
+ e.Num = str
+ }
+ s.error(err)
+ }
+ return math.Ldexp(f, n)
+ }
+ f, err := strconv.AtofN(str, n)
+ if err != nil {
+ s.error(err)
+ }
+ return f
+}
+
+// convertComplex converts the next token to a complex128 value.
+// The atof argument is a type-specific reader for the underlying type.
+// If we're reading complex64, atof will parse float32s and convert them
+// to float64's to avoid reproducing this code for each complex type.
+func (s *ss) scanComplex(verb int, n int) complex128 {
+ if !s.okVerb(verb, floatVerbs, "complex") {
+ return 0
+ }
+ s.skipSpace(false)
+ sreal, simag := s.complexTokens()
+ real := s.convertFloat(sreal, n/2)
+ imag := s.convertFloat(simag, n/2)
+ return complex(real, imag)
+}
+
+// convertString returns the string represented by the next input characters.
+// The format of the input is determined by the verb.
+func (s *ss) convertString(verb int) (str string) {
+ if !s.okVerb(verb, "svqx", "string") {
+ return ""
+ }
+ s.skipSpace(false)
+ switch verb {
+ case 'q':
+ str = s.quotedString()
+ case 'x':
+ str = s.hexString()
+ default:
+ str = string(s.token(true, notSpace)) // %s and %v just return the next word
+ }
+ // Empty strings other than with %q are not OK.
+ if len(str) == 0 && verb != 'q' && s.maxWid > 0 {
+ s.errorString("Scan: no data for string")
+ }
+ return
+}
+
+// quotedString returns the double- or back-quoted string represented by the next input characters.
+func (s *ss) quotedString() string {
+ quote := s.mustReadRune()
+ switch quote {
+ case '`':
+ // Back-quoted: Anything goes until EOF or back quote.
+ for {
+ rune := s.mustReadRune()
+ if rune == quote {
+ break
+ }
+ s.buf.WriteRune(rune)
+ }
+ return s.buf.String()
+ case '"':
+ // Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes.
+ s.buf.WriteRune(quote)
+ for {
+ rune := s.mustReadRune()
+ s.buf.WriteRune(rune)
+ if rune == '\\' {
+ // In a legal backslash escape, no matter how long, only the character
+ // immediately after the escape can itself be a backslash or quote.
+ // Thus we only need to protect the first character after the backslash.
+ rune := s.mustReadRune()
+ s.buf.WriteRune(rune)
+ } else if rune == '"' {
+ break
+ }
+ }
+ result, err := strconv.Unquote(s.buf.String())
+ if err != nil {
+ s.error(err)
+ }
+ return result
+ default:
+ s.errorString("expected quoted string")
+ }
+ return ""
+}
+
+// hexDigit returns the value of the hexadecimal digit
+func (s *ss) hexDigit(digit int) int {
+ switch digit {
+ case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ return digit - '0'
+ case 'a', 'b', 'c', 'd', 'e', 'f':
+ return 10 + digit - 'a'
+ case 'A', 'B', 'C', 'D', 'E', 'F':
+ return 10 + digit - 'A'
+ }
+ s.errorString("Scan: illegal hex digit")
+ return 0
+}
+
+// hexByte returns the next hex-encoded (two-character) byte from the input.
+// There must be either two hexadecimal digits or a space character in the input.
+func (s *ss) hexByte() (b byte, ok bool) {
+ rune1 := s.getRune()
+ if rune1 == eof {
+ return
+ }
+ if unicode.IsSpace(rune1) {
+ s.UnreadRune()
+ return
+ }
+ rune2 := s.mustReadRune()
+ return byte(s.hexDigit(rune1)<<4 | s.hexDigit(rune2)), true
+}
+
+// hexString returns the space-delimited hexpair-encoded string.
+func (s *ss) hexString() string {
+ for {
+ b, ok := s.hexByte()
+ if !ok {
+ break
+ }
+ s.buf.WriteByte(b)
+ }
+ if s.buf.Len() == 0 {
+ s.errorString("Scan: no hex data for %x string")
+ return ""
+ }
+ return s.buf.String()
+}
+
+const floatVerbs = "beEfFgGv"
+
+const hugeWid = 1 << 30
+
+// scanOne scans a single value, deriving the scanner from the type of the argument.
+func (s *ss) scanOne(verb int, field interface{}) {
+ s.buf.Reset()
+ var err os.Error
+ // If the parameter has its own Scan method, use that.
+ if v, ok := field.(Scanner); ok {
+ err = v.Scan(s, verb)
+ if err != nil {
+ if err == os.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ s.error(err)
+ }
+ return
+ }
+ switch v := field.(type) {
+ case *bool:
+ *v = s.scanBool(verb)
+ case *complex64:
+ *v = complex64(s.scanComplex(verb, 64))
+ case *complex128:
+ *v = s.scanComplex(verb, 128)
+ case *int:
+ *v = int(s.scanInt(verb, intBits))
+ case *int8:
+ *v = int8(s.scanInt(verb, 8))
+ case *int16:
+ *v = int16(s.scanInt(verb, 16))
+ case *int32:
+ *v = int32(s.scanInt(verb, 32))
+ case *int64:
+ *v = s.scanInt(verb, 64)
+ case *uint:
+ *v = uint(s.scanUint(verb, intBits))
+ case *uint8:
+ *v = uint8(s.scanUint(verb, 8))
+ case *uint16:
+ *v = uint16(s.scanUint(verb, 16))
+ case *uint32:
+ *v = uint32(s.scanUint(verb, 32))
+ case *uint64:
+ *v = s.scanUint(verb, 64)
+ case *uintptr:
+ *v = uintptr(s.scanUint(verb, uintptrBits))
+ // Floats are tricky because you want to scan in the precision of the result, not
+ // scan in high precision and convert, in order to preserve the correct error condition.
+ case *float32:
+ if s.okVerb(verb, floatVerbs, "float32") {
+ s.skipSpace(false)
+ *v = float32(s.convertFloat(s.floatToken(), 32))
+ }
+ case *float64:
+ if s.okVerb(verb, floatVerbs, "float64") {
+ s.skipSpace(false)
+ *v = s.convertFloat(s.floatToken(), 64)
+ }
+ case *string:
+ *v = s.convertString(verb)
+ case *[]byte:
+ // We scan to string and convert so we get a copy of the data.
+ // If we scanned to bytes, the slice would point at the buffer.
+ *v = []byte(s.convertString(verb))
+ default:
+ val := reflect.NewValue(v)
+ ptr, ok := val.(*reflect.PtrValue)
+ if !ok {
+ s.errorString("Scan: type not a pointer: " + val.Type().String())
+ return
+ }
+ switch v := ptr.Elem().(type) {
+ case *reflect.BoolValue:
+ v.Set(s.scanBool(verb))
+ case *reflect.IntValue:
+ v.Set(s.scanInt(verb, v.Type().Bits()))
+ case *reflect.UintValue:
+ v.Set(s.scanUint(verb, v.Type().Bits()))
+ case *reflect.StringValue:
+ v.Set(s.convertString(verb))
+ case *reflect.SliceValue:
+ // For now, can only handle (renamed) []byte.
+ typ := v.Type().(*reflect.SliceType)
+ if typ.Elem().Kind() != reflect.Uint8 {
+ goto CantHandle
+ }
+ str := s.convertString(verb)
+ v.Set(reflect.MakeSlice(typ, len(str), len(str)))
+ for i := 0; i < len(str); i++ {
+ v.Elem(i).(*reflect.UintValue).Set(uint64(str[i]))
+ }
+ case *reflect.FloatValue:
+ s.skipSpace(false)
+ v.Set(s.convertFloat(s.floatToken(), v.Type().Bits()))
+ case *reflect.ComplexValue:
+ v.Set(s.scanComplex(verb, v.Type().Bits()))
+ default:
+ CantHandle:
+ s.errorString("Scan: can't handle type: " + val.Type().String())
+ }
+ }
+}
+
+// errorHandler turns local panics into error returns. EOFs are benign.
+func errorHandler(errp *os.Error) {
+ if e := recover(); e != nil {
+ if se, ok := e.(scanError); ok { // catch local error
+ if se.err != os.EOF {
+ *errp = se.err
+ }
+ } else {
+ panic(e)
+ }
+ }
+}
+
+// doScan does the real work for scanning without a format string.
+func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) {
+ defer errorHandler(&err)
+ for _, field := range a {
+ s.scanOne('v', field)
+ numProcessed++
+ }
+ // Check for newline if required.
+ if !s.nlIsSpace {
+ for {
+ rune := s.getRune()
+ if rune == '\n' || rune == eof {
+ break
+ }
+ if !unicode.IsSpace(rune) {
+ s.errorString("Scan: expected newline")
+ break
+ }
+ }
+ }
+ return
+}
+
+// advance determines whether the next characters in the input match
+// those of the format. It returns the number of bytes (sic) consumed
+// in the format. Newlines included, all runs of space characters in
+// either input or format behave as a single space. This routine also
+// handles the %% case. If the return value is zero, either format
+// starts with a % (with no following %) or the input is empty.
+// If it is negative, the input did not match the string.
+func (s *ss) advance(format string) (i int) {
+ for i < len(format) {
+ fmtc, w := utf8.DecodeRuneInString(format[i:])
+ if fmtc == '%' {
+ // %% acts like a real percent
+ nextc, _ := utf8.DecodeRuneInString(format[i+w:]) // will not match % if string is empty
+ if nextc != '%' {
+ return
+ }
+ i += w // skip the first %
+ }
+ sawSpace := false
+ for unicode.IsSpace(fmtc) && i < len(format) {
+ sawSpace = true
+ i += w
+ fmtc, w = utf8.DecodeRuneInString(format[i:])
+ }
+ if sawSpace {
+ // There was space in the format, so there should be space (EOF)
+ // in the input.
+ inputc := s.getRune()
+ if inputc == eof {
+ return
+ }
+ if !unicode.IsSpace(inputc) {
+ // Space in format but not in input: error
+ s.errorString("expected space in input to match format")
+ }
+ s.skipSpace(true)
+ continue
+ }
+ inputc := s.mustReadRune()
+ if fmtc != inputc {
+ s.UnreadRune()
+ return -1
+ }
+ i += w
+ }
+ return
+}
+
+// doScanf does the real work when scanning with a format string.
+// At the moment, it handles only pointers to basic types.
+func (s *ss) doScanf(format string, a []interface{}) (numProcessed int, err os.Error) {
+ defer errorHandler(&err)
+ end := len(format) - 1
+ // We process one item per non-trivial format
+ for i := 0; i <= end; {
+ w := s.advance(format[i:])
+ if w > 0 {
+ i += w
+ continue
+ }
+ // Either we failed to advance, we have a percent character, or we ran out of input.
+ if format[i] != '%' {
+ // Can't advance format. Why not?
+ if w < 0 {
+ s.errorString("input does not match format")
+ }
+ // Otherwise at EOF; "too many operands" error handled below
+ break
+ }
+ i++ // % is one byte
+
+ // do we have 20 (width)?
+ var widPresent bool
+ s.maxWid, widPresent, i = parsenum(format, i, end)
+ if !widPresent {
+ s.maxWid = hugeWid
+ }
+ s.fieldLimit = s.limit
+ if f := s.count + s.maxWid; f < s.fieldLimit {
+ s.fieldLimit = f
+ }
+
+ c, w := utf8.DecodeRuneInString(format[i:])
+ i += w
+
+ if numProcessed >= len(a) { // out of operands
+ s.errorString("too few operands for format %" + format[i-w:])
+ break
+ }
+ field := a[numProcessed]
+
+ s.scanOne(c, field)
+ numProcessed++
+ s.fieldLimit = s.limit
+ }
+ if numProcessed < len(a) {
+ s.errorString("too many operands")
+ }
+ return
+}
--- /dev/null
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+ "bytes"
+ "io"
+ "math"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// runeUnreader is the interface to something that can unread runes.
+// If the object provided to Scan does not satisfy this interface,
+// a local buffer will be used to back up the input, but its contents
+// will be lost when Scan returns.
+type runeUnreader interface {
+ UnreadRune() os.Error
+}
+
+// ScanState represents the scanner state passed to custom scanners.
+// Scanners may do rune-at-a-time scanning or ask the ScanState
+// to discover the next space-delimited token.
+type ScanState interface {
+ // ReadRune reads the next rune (Unicode code point) from the input.
+ // If invoked during Scanln, Fscanln, or Sscanln, ReadRune() will
+ // return EOF after returning the first '\n' or when reading beyond
+ // the specified width.
+ ReadRune() (rune int, size int, err os.Error)
+ // UnreadRune causes the next call to ReadRune to return the same rune.
+ UnreadRune() os.Error
+ // Token skips space in the input if skipSpace is true, then returns the
+ // run of Unicode code points c satisfying f(c). If f is nil,
+ // !unicode.IsSpace(c) is used; that is, the token will hold non-space
+ // characters. Newlines are treated as space unless the scan operation
+ // is Scanln, Fscanln or Sscanln, in which case a newline is treated as
+ // EOF. The returned slice points to shared data that may be overwritten
+ // by the next call to Token, a call to a Scan function using the ScanState
+ // as input, or when the calling Scan method returns.
+ Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error)
+ // Width returns the value of the width option and whether it has been set.
+ // The unit is Unicode code points.
+ Width() (wid int, ok bool)
+ // Because ReadRune is implemented by the interface, Read should never be
+ // called by the scanning routines and a valid implementation of
+ // ScanState may choose always to return an error from Read.
+ Read(buf []byte) (n int, err os.Error)
+}
+
+// Scanner is implemented by any value that has a Scan method, which scans
+// the input for the representation of a value and stores the result in the
+// receiver, which must be a pointer to be useful. The Scan method is called
+// for any argument to Scan, Scanf, or Scanln that implements it.
+type Scanner interface {
+ Scan(state ScanState, verb int) os.Error
+}
+
+// Scan scans text read from standard input, storing successive
+// space-separated values into successive arguments. Newlines count
+// as space. It returns the number of items successfully scanned.
+// If that is less than the number of arguments, err will report why.
+func Scan(a ...interface{}) (n int, err os.Error) {
+ return Fscan(os.Stdin, a...)
+}
+
+// Scanln is similar to Scan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Scanln(a ...interface{}) (n int, err os.Error) {
+ return Fscanln(os.Stdin, a...)
+}
+
+// Scanf scans text read from standard input, storing successive
+// space-separated values into successive arguments as determined by
+// the format. It returns the number of items successfully scanned.
+func Scanf(format string, a ...interface{}) (n int, err os.Error) {
+ return Fscanf(os.Stdin, format, a...)
+}
+
+// Sscan scans the argument string, storing successive space-separated
+// values into successive arguments. Newlines count as space. It
+// returns the number of items successfully scanned. If that is less
+// than the number of arguments, err will report why.
+func Sscan(str string, a ...interface{}) (n int, err os.Error) {
+ return Fscan(strings.NewReader(str), a...)
+}
+
+// Sscanln is similar to Sscan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Sscanln(str string, a ...interface{}) (n int, err os.Error) {
+ return Fscanln(strings.NewReader(str), a...)
+}
+
+// Sscanf scans the argument string, storing successive space-separated
+// values into successive arguments as determined by the format. It
+// returns the number of items successfully parsed.
+func Sscanf(str string, format string, a ...interface{}) (n int, err os.Error) {
+ return Fscanf(strings.NewReader(str), format, a...)
+}
+
+// Fscan scans text read from r, storing successive space-separated
+// values into successive arguments. Newlines count as space. It
+// returns the number of items successfully scanned. If that is less
+// than the number of arguments, err will report why.
+func Fscan(r io.Reader, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, true, false)
+ n, err = s.doScan(a)
+ s.free(old)
+ return
+}
+
+// Fscanln is similar to Fscan, but stops scanning at a newline and
+// after the final item there must be a newline or EOF.
+func Fscanln(r io.Reader, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, false, true)
+ n, err = s.doScan(a)
+ s.free(old)
+ return
+}
+
+// Fscanf scans text read from r, storing successive space-separated
+// values into successive arguments as determined by the format. It
+// returns the number of items successfully parsed.
+func Fscanf(r io.Reader, format string, a ...interface{}) (n int, err os.Error) {
+ s, old := newScanState(r, false, false)
+ n, err = s.doScanf(format, a)
+ s.free(old)
+ return
+}
+
+// scanError represents an error generated by the scanning software.
+// It's used as a unique signature to identify such errors when recovering.
+type scanError struct {
+ err os.Error
+}
+
+const eof = -1
+
+// ss is the internal implementation of ScanState.
+type ss struct {
+ rr io.RuneReader // where to read input
+ buf bytes.Buffer // token accumulator
+ peekRune int // one-rune lookahead
+ prevRune int // last rune returned by ReadRune
+ count int // runes consumed so far.
+ atEOF bool // already read EOF
+ ssave
+}
+
+// ssave holds the parts of ss that need to be
+// saved and restored on recursive scans.
+type ssave struct {
+ validSave bool // is or was a part of an actual ss.
+ nlIsEnd bool // whether newline terminates scan
+ nlIsSpace bool // whether newline counts as white space
+ fieldLimit int // max value of ss.count for this field; fieldLimit <= limit
+ limit int // max value of ss.count.
+ maxWid int // width of this field.
+}
+
+// The Read method is only in ScanState so that ScanState
+// satisfies io.Reader. It will never be called when used as
+// intended, so there is no need to make it actually work.
+func (s *ss) Read(buf []byte) (n int, err os.Error) {
+ return 0, os.ErrorString("ScanState's Read should not be called. Use ReadRune")
+}
+
+func (s *ss) ReadRune() (rune int, size int, err os.Error) {
+ if s.peekRune >= 0 {
+ s.count++
+ rune = s.peekRune
+ size = utf8.RuneLen(rune)
+ s.prevRune = rune
+ s.peekRune = -1
+ return
+ }
+ if s.atEOF || s.nlIsEnd && s.prevRune == '\n' || s.count >= s.fieldLimit {
+ err = os.EOF
+ return
+ }
+
+ rune, size, err = s.rr.ReadRune()
+ if err == nil {
+ s.count++
+ s.prevRune = rune
+ } else if err == os.EOF {
+ s.atEOF = true
+ }
+ return
+}
+
+func (s *ss) Width() (wid int, ok bool) {
+ if s.maxWid == hugeWid {
+ return 0, false
+ }
+ return s.maxWid, true
+}
+
+// The public method returns an error; this private one panics.
+// If getRune reaches EOF, the return value is EOF (-1).
+func (s *ss) getRune() (rune int) {
+ rune, _, err := s.ReadRune()
+ if err != nil {
+ if err == os.EOF {
+ return eof
+ }
+ s.error(err)
+ }
+ return
+}
+
+// mustReadRune turns os.EOF into a panic(io.ErrUnexpectedEOF).
+// It is called in cases such as string scanning where an EOF is a
+// syntax error.
+func (s *ss) mustReadRune() (rune int) {
+ rune = s.getRune()
+ if rune == eof {
+ s.error(io.ErrUnexpectedEOF)
+ }
+ return
+}
+
+func (s *ss) UnreadRune() os.Error {
+ if u, ok := s.rr.(runeUnreader); ok {
+ u.UnreadRune()
+ } else {
+ s.peekRune = s.prevRune
+ }
+ s.count--
+ return nil
+}
+
+func (s *ss) error(err os.Error) {
+ panic(scanError{err})
+}
+
+func (s *ss) errorString(err string) {
+ panic(scanError{os.ErrorString(err)})
+}
+
+func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) {
+ defer func() {
+ if e := recover(); e != nil {
+ if se, ok := e.(scanError); ok {
+ err = se.err
+ } else {
+ panic(e)
+ }
+ }
+ }()
+ if f == nil {
+ f = notSpace
+ }
+ s.buf.Reset()
+ tok = s.token(skipSpace, f)
+ return
+}
+
+// notSpace is the default scanning function used in Token.
+func notSpace(r int) bool {
+ return !unicode.IsSpace(r)
+}
+
+// readRune is a structure to enable reading UTF-8 encoded code points
+// from an io.Reader. It is used if the Reader given to the scanner does
+// not already implement io.RuneReader.
+type readRune struct {
+ reader io.Reader
+ buf [utf8.UTFMax]byte // used only inside ReadRune
+ pending int // number of bytes in pendBuf; only >0 for bad UTF-8
+ pendBuf [utf8.UTFMax]byte // bytes left over
+}
+
+// readByte returns the next byte from the input, which may be
+// left over from a previous read if the UTF-8 was ill-formed.
+func (r *readRune) readByte() (b byte, err os.Error) {
+ if r.pending > 0 {
+ b = r.pendBuf[0]
+ copy(r.pendBuf[0:], r.pendBuf[1:])
+ r.pending--
+ return
+ }
+ _, err = r.reader.Read(r.pendBuf[0:1])
+ return r.pendBuf[0], err
+}
+
+// unread saves the bytes for the next read.
+func (r *readRune) unread(buf []byte) {
+ copy(r.pendBuf[r.pending:], buf)
+ r.pending += len(buf)
+}
+
+// ReadRune returns the next UTF-8 encoded code point from the
+// io.Reader inside r.
+func (r *readRune) ReadRune() (rune int, size int, err os.Error) {
+ r.buf[0], err = r.readByte()
+ if err != nil {
+ return 0, 0, err
+ }
+ if r.buf[0] < utf8.RuneSelf { // fast check for common ASCII case
+ rune = int(r.buf[0])
+ return
+ }
+ var n int
+ for n = 1; !utf8.FullRune(r.buf[0:n]); n++ {
+ r.buf[n], err = r.readByte()
+ if err != nil {
+ if err == os.EOF {
+ err = nil
+ break
+ }
+ return
+ }
+ }
+ rune, size = utf8.DecodeRune(r.buf[0:n])
+ if size < n { // an error
+ r.unread(r.buf[size:n])
+ }
+ return
+}
+
+
+var ssFree = newCache(func() interface{} { return new(ss) })
+
+// Allocate a new ss struct or grab a cached one.
+func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) {
+ // If the reader is a *ss, then we've got a recursive
+ // call to Scan, so re-use the scan state.
+ s, ok := r.(*ss)
+ if ok {
+ old = s.ssave
+ s.limit = s.fieldLimit
+ s.nlIsEnd = nlIsEnd || s.nlIsEnd
+ s.nlIsSpace = nlIsSpace
+ return
+ }
+
+ s = ssFree.get().(*ss)
+ if rr, ok := r.(io.RuneReader); ok {
+ s.rr = rr
+ } else {
+ s.rr = &readRune{reader: r}
+ }
+ s.nlIsSpace = nlIsSpace
+ s.nlIsEnd = nlIsEnd
+ s.prevRune = -1
+ s.peekRune = -1
+ s.atEOF = false
+ s.limit = hugeWid
+ s.fieldLimit = hugeWid
+ s.maxWid = hugeWid
+ s.validSave = true
+ return
+}
+
+// Save used ss structs in ssFree; avoid an allocation per invocation.
+func (s *ss) free(old ssave) {
+ // If it was used recursively, just restore the old state.
+ if old.validSave {
+ s.ssave = old
+ return
+ }
+ // Don't hold on to ss structs with large buffers.
+ if cap(s.buf.Bytes()) > 1024 {
+ return
+ }
+ s.buf.Reset()
+ s.rr = nil
+ ssFree.put(s)
+}
+
+// skipSpace skips spaces and maybe newlines.
+func (s *ss) skipSpace(stopAtNewline bool) {
+ for {
+ rune := s.getRune()
+ if rune == eof {
+ return
+ }
+ if rune == '\n' {
+ if stopAtNewline {
+ break
+ }
+ if s.nlIsSpace {
+ continue
+ }
+ s.errorString("unexpected newline")
+ return
+ }
+ if !unicode.IsSpace(rune) {
+ s.UnreadRune()
+ break
+ }
+ }
+}
+
+
+// token returns the next space-delimited string from the input. It
+// skips white space. For Scanln, it stops at newlines. For Scan,
+// newlines are treated as spaces.
+func (s *ss) token(skipSpace bool, f func(int) bool) []byte {
+ if skipSpace {
+ s.skipSpace(false)
+ }
+ // read until white space or newline
+ for {
+ rune := s.getRune()
+ if rune == eof {
+ break
+ }
+ if !f(rune) {
+ s.UnreadRune()
+ break
+ }
+ s.buf.WriteRune(rune)
+ }
+ return s.buf.Bytes()
+}
+
+// typeError indicates that the type of the operand did not match the format
+func (s *ss) typeError(field interface{}, expected string) {
+ s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String())
+}
+
+var complexError = os.ErrorString("syntax error scanning complex number")
+var boolError = os.ErrorString("syntax error scanning boolean")
+
+// consume reads the next rune in the input and reports whether it is in the ok string.
+// If accept is true, it puts the character into the input token.
+func (s *ss) consume(ok string, accept bool) bool {
+ rune := s.getRune()
+ if rune == eof {
+ return false
+ }
+ if strings.IndexRune(ok, rune) >= 0 {
+ if accept {
+ s.buf.WriteRune(rune)
+ }
+ return true
+ }
+ if rune != eof && accept {
+ s.UnreadRune()
+ }
+ return false
+}
+
+// peek reports whether the next character is in the ok string, without consuming it.
+func (s *ss) peek(ok string) bool {
+ rune := s.getRune()
+ if rune != eof {
+ s.UnreadRune()
+ }
+ return strings.IndexRune(ok, rune) >= 0
+}
+
+// accept checks the next rune in the input. If it's a byte (sic) in the string, it puts it in the
+// buffer and returns true. Otherwise it return false.
+func (s *ss) accept(ok string) bool {
+ return s.consume(ok, true)
+}
+
+// okVerb verifies that the verb is present in the list, setting s.err appropriately if not.
+func (s *ss) okVerb(verb int, okVerbs, typ string) bool {
+ for _, v := range okVerbs {
+ if v == verb {
+ return true
+ }
+ }
+ s.errorString("bad verb %" + string(verb) + " for " + typ)
+ return false
+}
+
+// scanBool returns the value of the boolean represented by the next token.
+func (s *ss) scanBool(verb int) bool {
+ if !s.okVerb(verb, "tv", "boolean") {
+ return false
+ }
+ // Syntax-checking a boolean is annoying. We're not fastidious about case.
+ switch s.mustReadRune() {
+ case '0':
+ return false
+ case '1':
+ return true
+ case 't', 'T':
+ if s.accept("rR") && (!s.accept("uU") || !s.accept("eE")) {
+ s.error(boolError)
+ }
+ return true
+ case 'f', 'F':
+ if s.accept("aL") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) {
+ s.error(boolError)
+ }
+ return false
+ }
+ return false
+}
+
+// Numerical elements
+const (
+ binaryDigits = "01"
+ octalDigits = "01234567"
+ decimalDigits = "0123456789"
+ hexadecimalDigits = "0123456789aAbBcCdDeEfF"
+ sign = "+-"
+ period = "."
+ exponent = "eEp"
+)
+
+// getBase returns the numeric base represented by the verb and its digit string.
+func (s *ss) getBase(verb int) (base int, digits string) {
+ s.okVerb(verb, "bdoUxXv", "integer") // sets s.err
+ base = 10
+ digits = decimalDigits
+ switch verb {
+ case 'b':
+ base = 2
+ digits = binaryDigits
+ case 'o':
+ base = 8
+ digits = octalDigits
+ case 'x', 'X', 'U':
+ base = 16
+ digits = hexadecimalDigits
+ }
+ return
+}
+
+// scanNumber returns the numerical string with specified digits starting here.
+func (s *ss) scanNumber(digits string, haveDigits bool) string {
+ if !haveDigits && !s.accept(digits) {
+ s.errorString("expected integer")
+ }
+ for s.accept(digits) {
+ }
+ return s.buf.String()
+}
+
+// scanRune returns the next rune value in the input.
+func (s *ss) scanRune(bitSize int) int64 {
+ rune := int64(s.mustReadRune())
+ n := uint(bitSize)
+ x := (rune << (64 - n)) >> (64 - n)
+ if x != rune {
+ s.errorString("overflow on character value " + string(rune))
+ }
+ return rune
+}
+
+// scanBasePrefix reports whether the integer begins with a 0 or 0x,
+// and returns the base, digit string, and whether a zero was found.
+// It is called only if the verb is %v.
+func (s *ss) scanBasePrefix() (base int, digits string, found bool) {
+ if !s.peek("0") {
+ return 10, decimalDigits, false
+ }
+ s.accept("0")
+ found = true // We've put a digit into the token buffer.
+ // Special cases for '0' && '0x'
+ base, digits = 8, octalDigits
+ if s.peek("xX") {
+ s.consume("xX", false)
+ base, digits = 16, hexadecimalDigits
+ }
+ return
+}
+
+// scanInt returns the value of the integer represented by the next
+// token, checking for overflow. Any error is stored in s.err.
+func (s *ss) scanInt(verb int, bitSize int) int64 {
+ if verb == 'c' {
+ return s.scanRune(bitSize)
+ }
+ s.skipSpace(false)
+ base, digits := s.getBase(verb)
+ haveDigits := false
+ if verb == 'U' {
+ if !s.consume("U", false) || !s.consume("+", false) {
+ s.errorString("bad unicode format ")
+ }
+ } else {
+ s.accept(sign) // If there's a sign, it will be left in the token buffer.
+ if verb == 'v' {
+ base, digits, haveDigits = s.scanBasePrefix()
+ }
+ }
+ tok := s.scanNumber(digits, haveDigits)
+ i, err := strconv.Btoi64(tok, base)
+ if err != nil {
+ s.error(err)
+ }
+ n := uint(bitSize)
+ x := (i << (64 - n)) >> (64 - n)
+ if x != i {
+ s.errorString("integer overflow on token " + tok)
+ }
+ return i
+}
+
+// scanUint returns the value of the unsigned integer represented
+// by the next token, checking for overflow. Any error is stored in s.err.
+func (s *ss) scanUint(verb int, bitSize int) uint64 {
+ if verb == 'c' {
+ return uint64(s.scanRune(bitSize))
+ }
+ s.skipSpace(false)
+ base, digits := s.getBase(verb)
+ haveDigits := false
+ if verb == 'U' {
+ if !s.consume("U", false) || !s.consume("+", false) {
+ s.errorString("bad unicode format ")
+ }
+ } else if verb == 'v' {
+ base, digits, haveDigits = s.scanBasePrefix()
+ }
+ tok := s.scanNumber(digits, haveDigits)
+ i, err := strconv.Btoui64(tok, base)
+ if err != nil {
+ s.error(err)
+ }
+ n := uint(bitSize)
+ x := (i << (64 - n)) >> (64 - n)
+ if x != i {
+ s.errorString("unsigned integer overflow on token " + tok)
+ }
+ return i
+}
+
+// floatToken returns the floating-point number starting here, no longer than swid
+// if the width is specified. It's not rigorous about syntax because it doesn't check that
+// we have at least some digits, but Atof will do that.
+func (s *ss) floatToken() string {
+ s.buf.Reset()
+ // NaN?
+ if s.accept("nN") && s.accept("aA") && s.accept("nN") {
+ return s.buf.String()
+ }
+ // leading sign?
+ s.accept(sign)
+ // Inf?
+ if s.accept("iI") && s.accept("nN") && s.accept("fF") {
+ return s.buf.String()
+ }
+ // digits?
+ for s.accept(decimalDigits) {
+ }
+ // decimal point?
+ if s.accept(period) {
+ // fraction?
+ for s.accept(decimalDigits) {
+ }
+ }
+ // exponent?
+ if s.accept(exponent) {
+ // leading sign?
+ s.accept(sign)
+ // digits?
+ for s.accept(decimalDigits) {
+ }
+ }
+ return s.buf.String()
+}
+
+// complexTokens returns the real and imaginary parts of the complex number starting here.
+// The number might be parenthesized and has the format (N+Ni) where N is a floating-point
+// number and there are no spaces within.
+func (s *ss) complexTokens() (real, imag string) {
+ // TODO: accept N and Ni independently?
+ parens := s.accept("(")
+ real = s.floatToken()
+ s.buf.Reset()
+ // Must now have a sign.
+ if !s.accept("+-") {
+ s.error(complexError)
+ }
+ // Sign is now in buffer
+ imagSign := s.buf.String()
+ imag = s.floatToken()
+ if !s.accept("i") {
+ s.error(complexError)
+ }
+ if parens && !s.accept(")") {
+ s.error(complexError)
+ }
+ return real, imagSign + imag
+}
+
+// convertFloat converts the string to a float64value.
+func (s *ss) convertFloat(str string, n int) float64 {
+ if p := strings.Index(str, "p"); p >= 0 {
+ // Atof doesn't handle power-of-2 exponents,
+ // but they're easy to evaluate.
+ f, err := strconv.AtofN(str[:p], n)
+ if err != nil {
+ // Put full string into error.
+ if e, ok := err.(*strconv.NumError); ok {
+ e.Num = str
+ }
+ s.error(err)
+ }
+ n, err := strconv.Atoi(str[p+1:])
+ if err != nil {
+ // Put full string into error.
+ if e, ok := err.(*strconv.NumError); ok {
+ e.Num = str
+ }
+ s.error(err)
+ }
+ return math.Ldexp(f, n)
+ }
+ f, err := strconv.AtofN(str, n)
+ if err != nil {
+ s.error(err)
+ }
+ return f
+}
+
+// convertComplex converts the next token to a complex128 value.
+// The atof argument is a type-specific reader for the underlying type.
+// If we're reading complex64, atof will parse float32s and convert them
+// to float64's to avoid reproducing this code for each complex type.
+func (s *ss) scanComplex(verb int, n int) complex128 {
+ if !s.okVerb(verb, floatVerbs, "complex") {
+ return 0
+ }
+ s.skipSpace(false)
+ sreal, simag := s.complexTokens()
+ real := s.convertFloat(sreal, n/2)
+ imag := s.convertFloat(simag, n/2)
+ return complex(real, imag)
+}
+
+// convertString returns the string represented by the next input characters.
+// The format of the input is determined by the verb.
+func (s *ss) convertString(verb int) (str string) {
+ if !s.okVerb(verb, "svqx", "string") {
+ return ""
+ }
+ s.skipSpace(false)
+ switch verb {
+ case 'q':
+ str = s.quotedString()
+ case 'x':
+ str = s.hexString()
+ default:
+ str = string(s.token(true, notSpace)) // %s and %v just return the next word
+ }
+ // Empty strings other than with %q are not OK.
+ if len(str) == 0 && verb != 'q' && s.maxWid > 0 {
+ s.errorString("Scan: no data for string")
+ }
+ return
+}
+
+// quotedString returns the double- or back-quoted string represented by the next input characters.
+func (s *ss) quotedString() string {
+ quote := s.mustReadRune()
+ switch quote {
+ case '`':
+ // Back-quoted: Anything goes until EOF or back quote.
+ for {
+ rune := s.mustReadRune()
+ if rune == quote {
+ break
+ }
+ s.buf.WriteRune(rune)
+ }
+ return s.buf.String()
+ case '"':
+ // Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes.
+ s.buf.WriteRune(quote)
+ for {
+ rune := s.mustReadRune()
+ s.buf.WriteRune(rune)
+ if rune == '\\' {
+ // In a legal backslash escape, no matter how long, only the character
+ // immediately after the escape can itself be a backslash or quote.
+ // Thus we only need to protect the first character after the backslash.
+ rune := s.mustReadRune()
+ s.buf.WriteRune(rune)
+ } else if rune == '"' {
+ break
+ }
+ }
+ result, err := strconv.Unquote(s.buf.String())
+ if err != nil {
+ s.error(err)
+ }
+ return result
+ default:
+ s.errorString("expected quoted string")
+ }
+ return ""
+}
+
+// hexDigit returns the value of the hexadecimal digit
+func (s *ss) hexDigit(digit int) int {
+ switch digit {
+ case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ return digit - '0'
+ case 'a', 'b', 'c', 'd', 'e', 'f':
+ return 10 + digit - 'a'
+ case 'A', 'B', 'C', 'D', 'E', 'F':
+ return 10 + digit - 'A'
+ }
+ s.errorString("Scan: illegal hex digit")
+ return 0
+}
+
+// hexByte returns the next hex-encoded (two-character) byte from the input.
+// There must be either two hexadecimal digits or a space character in the input.
+func (s *ss) hexByte() (b byte, ok bool) {
+ rune1 := s.getRune()
+ if rune1 == eof {
+ return
+ }
+ if unicode.IsSpace(rune1) {
+ s.UnreadRune()
+ return
+ }
+ rune2 := s.mustReadRune()
+ return byte(s.hexDigit(rune1)<<4 | s.hexDigit(rune2)), true
+}
+
+// hexString returns the space-delimited hexpair-encoded string.
+func (s *ss) hexString() string {
+ for {
+ b, ok := s.hexByte()
+ if !ok {
+ break
+ }
+ s.buf.WriteByte(b)
+ }
+ if s.buf.Len() == 0 {
+ s.errorString("Scan: no hex data for %x string")
+ return ""
+ }
+ return s.buf.String()
+}
+
+const floatVerbs = "beEfFgGv"
+
+const hugeWid = 1 << 30
+
+// scanOne scans a single value, deriving the scanner from the type of the argument.
+func (s *ss) scanOne(verb int, field interface{}) {
+ s.buf.Reset()
+ var err os.Error
+ // If the parameter has its own Scan method, use that.
+ if v, ok := field.(Scanner); ok {
+ err = v.Scan(s, verb)
+ if err != nil {
+ if err == os.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ s.error(err)
+ }
+ return
+ }
+ switch v := field.(type) {
+ case *bool:
+ *v = s.scanBool(verb)
+ case *complex64:
+ *v = complex64(s.scanComplex(verb, 64))
+ case *complex128:
+ *v = s.scanComplex(verb, 128)
+ case *int:
+ *v = int(s.scanInt(verb, intBits))
+ case *int8:
+ *v = int8(s.scanInt(verb, 8))
+ case *int16:
+ *v = int16(s.scanInt(verb, 16))
+ case *int32:
+ *v = int32(s.scanInt(verb, 32))
+ case *int64:
+ *v = s.scanInt(verb, 64)
+ case *uint:
+ *v = uint(s.scanUint(verb, intBits))
+ case *uint8:
+ *v = uint8(s.scanUint(verb, 8))
+ case *uint16:
+ *v = uint16(s.scanUint(verb, 16))
+ case *uint32:
+ *v = uint32(s.scanUint(verb, 32))
+ case *uint64:
+ *v = s.scanUint(verb, 64)
+ case *uintptr:
+ *v = uintptr(s.scanUint(verb, uintptrBits))
+ // Floats are tricky because you want to scan in the precision of the result, not
+ // scan in high precision and convert, in order to preserve the correct error condition.
+ case *float32:
+ if s.okVerb(verb, floatVerbs, "float32") {
+ s.skipSpace(false)
+ *v = float32(s.convertFloat(s.floatToken(), 32))
+ }
+ case *float64:
+ if s.okVerb(verb, floatVerbs, "float64") {
+ s.skipSpace(false)
+ *v = s.convertFloat(s.floatToken(), 64)
+ }
+ case *string:
+ *v = s.convertString(verb)
+ case *[]byte:
+ // We scan to string and convert so we get a copy of the data.
+ // If we scanned to bytes, the slice would point at the buffer.
+ *v = []byte(s.convertString(verb))
+ default:
+ val := reflect.NewValue(v)
+ ptr := val
+ if ptr.Kind() != reflect.Ptr {
+ s.errorString("Scan: type not a pointer: " + val.Type().String())
+ return
+ }
+ switch v := ptr.Elem(); v.Kind() {
+ case reflect.Bool:
+ v.SetBool(s.scanBool(verb))
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ v.SetInt(s.scanInt(verb, v.Type().Bits()))
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ v.SetUint(s.scanUint(verb, v.Type().Bits()))
+ case reflect.String:
+ v.SetString(s.convertString(verb))
+ case reflect.Slice:
+ // For now, can only handle (renamed) []byte.
+ typ := v.Type()
+ if typ.Elem().Kind() != reflect.Uint8 {
+ goto CantHandle
+ }
+ str := s.convertString(verb)
+ v.Set(reflect.MakeSlice(typ, len(str), len(str)))
+ for i := 0; i < len(str); i++ {
+ v.Index(i).SetUint(uint64(str[i]))
+ }
+ case reflect.Float32, reflect.Float64:
+ s.skipSpace(false)
+ v.SetFloat(s.convertFloat(s.floatToken(), v.Type().Bits()))
+ case reflect.Complex64, reflect.Complex128:
+ v.SetComplex(s.scanComplex(verb, v.Type().Bits()))
+ default:
+ CantHandle:
+ s.errorString("Scan: can't handle type: " + val.Type().String())
+ }
+ }
+}
+
+// errorHandler turns local panics into error returns. EOFs are benign.
+func errorHandler(errp *os.Error) {
+ if e := recover(); e != nil {
+ if se, ok := e.(scanError); ok { // catch local error
+ if se.err != os.EOF {
+ *errp = se.err
+ }
+ } else {
+ panic(e)
+ }
+ }
+}
+
+// doScan does the real work for scanning without a format string.
+func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) {
+ defer errorHandler(&err)
+ for _, field := range a {
+ s.scanOne('v', field)
+ numProcessed++
+ }
+ // Check for newline if required.
+ if !s.nlIsSpace {
+ for {
+ rune := s.getRune()
+ if rune == '\n' || rune == eof {
+ break
+ }
+ if !unicode.IsSpace(rune) {
+ s.errorString("Scan: expected newline")
+ break
+ }
+ }
+ }
+ return
+}
+
+// advance determines whether the next characters in the input match
+// those of the format. It returns the number of bytes (sic) consumed
+// in the format. Newlines included, all runs of space characters in
+// either input or format behave as a single space. This routine also
+// handles the %% case. If the return value is zero, either format
+// starts with a % (with no following %) or the input is empty.
+// If it is negative, the input did not match the string.
+func (s *ss) advance(format string) (i int) {
+ for i < len(format) {
+ fmtc, w := utf8.DecodeRuneInString(format[i:])
+ if fmtc == '%' {
+ // %% acts like a real percent
+ nextc, _ := utf8.DecodeRuneInString(format[i+w:]) // will not match % if string is empty
+ if nextc != '%' {
+ return
+ }
+ i += w // skip the first %
+ }
+ sawSpace := false
+ for unicode.IsSpace(fmtc) && i < len(format) {
+ sawSpace = true
+ i += w
+ fmtc, w = utf8.DecodeRuneInString(format[i:])
+ }
+ if sawSpace {
+ // There was space in the format, so there should be space (EOF)
+ // in the input.
+ inputc := s.getRune()
+ if inputc == eof {
+ return
+ }
+ if !unicode.IsSpace(inputc) {
+ // Space in format but not in input: error
+ s.errorString("expected space in input to match format")
+ }
+ s.skipSpace(true)
+ continue
+ }
+ inputc := s.mustReadRune()
+ if fmtc != inputc {
+ s.UnreadRune()
+ return -1
+ }
+ i += w
+ }
+ return
+}
+
+// doScanf does the real work when scanning with a format string.
+// At the moment, it handles only pointers to basic types.
+func (s *ss) doScanf(format string, a []interface{}) (numProcessed int, err os.Error) {
+ defer errorHandler(&err)
+ end := len(format) - 1
+ // We process one item per non-trivial format
+ for i := 0; i <= end; {
+ w := s.advance(format[i:])
+ if w > 0 {
+ i += w
+ continue
+ }
+ // Either we failed to advance, we have a percent character, or we ran out of input.
+ if format[i] != '%' {
+ // Can't advance format. Why not?
+ if w < 0 {
+ s.errorString("input does not match format")
+ }
+ // Otherwise at EOF; "too many operands" error handled below
+ break
+ }
+ i++ // % is one byte
+
+ // do we have 20 (width)?
+ var widPresent bool
+ s.maxWid, widPresent, i = parsenum(format, i, end)
+ if !widPresent {
+ s.maxWid = hugeWid
+ }
+ s.fieldLimit = s.limit
+ if f := s.count + s.maxWid; f < s.fieldLimit {
+ s.fieldLimit = f
+ }
+
+ c, w := utf8.DecodeRuneInString(format[i:])
+ i += w
+
+ if numProcessed >= len(a) { // out of operands
+ s.errorString("too few operands for format %" + format[i-w:])
+ break
+ }
+ field := a[numProcessed]
+
+ s.scanOne(c, field)
+ numProcessed++
+ s.fieldLimit = s.limit
+ }
+ if numProcessed < len(a) {
+ s.errorString("too many operands")
+ }
+ return
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package aids in the testing of code that uses channels.
+package script
+
+import (
+ "fmt"
+ "os"
+ "rand"
+ "reflect"
+ "strings"
+)
+
+// An Event is an element in a partially ordered set that either sends a value
+// to a channel or expects a value from a channel.
+type Event struct {
+ name string
+ occurred bool
+ predecessors []*Event
+ action action
+}
+
+type action interface {
+ // getSend returns nil if the action is not a send action.
+ getSend() sendAction
+ // getRecv returns nil if the action is not a receive action.
+ getRecv() recvAction
+ // getChannel returns the channel that the action operates on.
+ getChannel() interface{}
+}
+
+type recvAction interface {
+ recvMatch(interface{}) bool
+}
+
+type sendAction interface {
+ send()
+}
+
+// isReady returns true if all the predecessors of an Event have occurred.
+func (e Event) isReady() bool {
+ for _, predecessor := range e.predecessors {
+ if !predecessor.occurred {
+ return false
+ }
+ }
+
+ return true
+}
+
+// A Recv action reads a value from a channel and uses reflect.DeepMatch to
+// compare it with an expected value.
+type Recv struct {
+ Channel interface{}
+ Expected interface{}
+}
+
+func (r Recv) getRecv() recvAction { return r }
+
+func (Recv) getSend() sendAction { return nil }
+
+func (r Recv) getChannel() interface{} { return r.Channel }
+
+func (r Recv) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelRecv)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return reflect.DeepEqual(c.value, r.Expected)
+}
+
+// A RecvMatch action reads a value from a channel and calls a function to
+// determine if the value matches.
+type RecvMatch struct {
+ Channel interface{}
+ Match func(interface{}) bool
+}
+
+func (r RecvMatch) getRecv() recvAction { return r }
+
+func (RecvMatch) getSend() sendAction { return nil }
+
+func (r RecvMatch) getChannel() interface{} { return r.Channel }
+
+func (r RecvMatch) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelRecv)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return r.Match(c.value)
+}
+
+// A Closed action matches if the given channel is closed. The closing is
+// treated as an event, not a state, thus Closed will only match once for a
+// given channel.
+type Closed struct {
+ Channel interface{}
+}
+
+func (r Closed) getRecv() recvAction { return r }
+
+func (Closed) getSend() sendAction { return nil }
+
+func (r Closed) getChannel() interface{} { return r.Channel }
+
+func (r Closed) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelClosed)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return true
+}
+
+// A Send action sends a value to a channel. The value must match the
+// type of the channel exactly unless the channel if of type chan interface{}.
+type Send struct {
+ Channel interface{}
+ Value interface{}
+}
+
+func (Send) getRecv() recvAction { return nil }
+
+func (s Send) getSend() sendAction { return s }
+
+func (s Send) getChannel() interface{} { return s.Channel }
+
+type empty struct {
+ x interface{}
+}
+
+func newEmptyInterface(e empty) reflect.Value {
+ return reflect.NewValue(e).(*reflect.StructValue).Field(0)
+}
+
+func (s Send) send() {
+ // With reflect.ChanValue.Send, we must match the types exactly. So, if
+ // s.Channel is a chan interface{} we convert s.Value to an interface{}
+ // first.
+ c := reflect.NewValue(s.Channel).(*reflect.ChanValue)
+ var v reflect.Value
+ if iface, ok := c.Type().(*reflect.ChanType).Elem().(*reflect.InterfaceType); ok && iface.NumMethod() == 0 {
+ v = newEmptyInterface(empty{s.Value})
+ } else {
+ v = reflect.NewValue(s.Value)
+ }
+ c.Send(v)
+}
+
+// A Close action closes the given channel.
+type Close struct {
+ Channel interface{}
+}
+
+func (Close) getRecv() recvAction { return nil }
+
+func (s Close) getSend() sendAction { return s }
+
+func (s Close) getChannel() interface{} { return s.Channel }
+
+func (s Close) send() { reflect.NewValue(s.Channel).(*reflect.ChanValue).Close() }
+
+// A ReceivedUnexpected error results if no active Events match a value
+// received from a channel.
+type ReceivedUnexpected struct {
+ Value interface{}
+ ready []*Event
+}
+
+func (r ReceivedUnexpected) String() string {
+ names := make([]string, len(r.ready))
+ for i, v := range r.ready {
+ names[i] = v.name
+ }
+ return fmt.Sprintf("received unexpected value on one of the channels: %#v. Runnable events: %s", r.Value, strings.Join(names, ", "))
+}
+
+// A SetupError results if there is a error with the configuration of a set of
+// Events.
+type SetupError string
+
+func (s SetupError) String() string { return string(s) }
+
+func NewEvent(name string, predecessors []*Event, action action) *Event {
+ e := &Event{name, false, predecessors, action}
+ return e
+}
+
+// Given a set of Events, Perform repeatedly iterates over the set and finds the
+// subset of ready Events (that is, all of their predecessors have
+// occurred). From that subset, it pseudo-randomly selects an Event to perform.
+// If the Event is a send event, the send occurs and Perform recalculates the ready
+// set. If the event is a receive event, Perform waits for a value from any of the
+// channels that are contained in any of the events. That value is then matched
+// against the ready events. The first event that matches is considered to
+// have occurred and Perform recalculates the ready set.
+//
+// Perform continues this until all Events have occurred.
+//
+// Note that uncollected goroutines may still be reading from any of the
+// channels read from after Perform returns.
+//
+// For example, consider the problem of testing a function that reads values on
+// one channel and echos them to two output channels. To test this we would
+// create three events: a send event and two receive events. Each of the
+// receive events must list the send event as a predecessor but there is no
+// ordering between the receive events.
+//
+// send := NewEvent("send", nil, Send{c, 1})
+// recv1 := NewEvent("recv 1", []*Event{send}, Recv{c, 1})
+// recv2 := NewEvent("recv 2", []*Event{send}, Recv{c, 1})
+// Perform(0, []*Event{send, recv1, recv2})
+//
+// At first, only the send event would be in the ready set and thus Perform will
+// send a value to the input channel. Now the two receive events are ready and
+// Perform will match each of them against the values read from the output channels.
+//
+// It would be invalid to list one of the receive events as a predecessor of
+// the other. At each receive step, all the receive channels are considered,
+// thus Perform may see a value from a channel that is not in the current ready
+// set and fail.
+func Perform(seed int64, events []*Event) (err os.Error) {
+ r := rand.New(rand.NewSource(seed))
+
+ channels, err := getChannels(events)
+ if err != nil {
+ return
+ }
+ multiplex := make(chan interface{})
+ for _, channel := range channels {
+ go recvValues(multiplex, channel)
+ }
+
+Outer:
+ for {
+ ready, err := readyEvents(events)
+ if err != nil {
+ return err
+ }
+
+ if len(ready) == 0 {
+ // All events occurred.
+ break
+ }
+
+ event := ready[r.Intn(len(ready))]
+ if send := event.action.getSend(); send != nil {
+ send.send()
+ event.occurred = true
+ continue
+ }
+
+ v := <-multiplex
+ for _, event := range ready {
+ if recv := event.action.getRecv(); recv != nil && recv.recvMatch(v) {
+ event.occurred = true
+ continue Outer
+ }
+ }
+
+ return ReceivedUnexpected{v, ready}
+ }
+
+ return nil
+}
+
+// getChannels returns all the channels listed in any receive events.
+func getChannels(events []*Event) ([]interface{}, os.Error) {
+ channels := make([]interface{}, len(events))
+
+ j := 0
+ for _, event := range events {
+ if recv := event.action.getRecv(); recv == nil {
+ continue
+ }
+ c := event.action.getChannel()
+ if _, ok := reflect.NewValue(c).(*reflect.ChanValue); !ok {
+ return nil, SetupError("one of the channel values is not a channel")
+ }
+
+ duplicate := false
+ for _, other := range channels[0:j] {
+ if c == other {
+ duplicate = true
+ break
+ }
+ }
+
+ if !duplicate {
+ channels[j] = c
+ j++
+ }
+ }
+
+ return channels[0:j], nil
+}
+
+// recvValues is a multiplexing helper function. It reads values from the given
+// channel repeatedly, wrapping them up as either a channelRecv or
+// channelClosed structure, and forwards them to the multiplex channel.
+func recvValues(multiplex chan<- interface{}, channel interface{}) {
+ c := reflect.NewValue(channel).(*reflect.ChanValue)
+
+ for {
+ v, ok := c.Recv()
+ if !ok {
+ multiplex <- channelClosed{channel}
+ return
+ }
+
+ multiplex <- channelRecv{channel, v.Interface()}
+ }
+}
+
+type channelClosed struct {
+ channel interface{}
+}
+
+type channelRecv struct {
+ channel interface{}
+ value interface{}
+}
+
+// readyEvents returns the subset of events that are ready.
+func readyEvents(events []*Event) ([]*Event, os.Error) {
+ ready := make([]*Event, len(events))
+
+ j := 0
+ eventsWaiting := false
+ for _, event := range events {
+ if event.occurred {
+ continue
+ }
+
+ eventsWaiting = true
+ if event.isReady() {
+ ready[j] = event
+ j++
+ }
+ }
+
+ if j == 0 && eventsWaiting {
+ names := make([]string, len(events))
+ for _, event := range events {
+ if event.occurred {
+ continue
+ }
+ names[j] = event.name
+ }
+
+ return nil, SetupError("dependency cycle in events. These events are waiting to run but cannot: " + strings.Join(names, ", "))
+ }
+
+ return ready[0:j], nil
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package aids in the testing of code that uses channels.
+package script
+
+import (
+ "fmt"
+ "os"
+ "rand"
+ "reflect"
+ "strings"
+)
+
+// An Event is an element in a partially ordered set that either sends a value
+// to a channel or expects a value from a channel.
+type Event struct {
+ name string
+ occurred bool
+ predecessors []*Event
+ action action
+}
+
+type action interface {
+ // getSend returns nil if the action is not a send action.
+ getSend() sendAction
+ // getRecv returns nil if the action is not a receive action.
+ getRecv() recvAction
+ // getChannel returns the channel that the action operates on.
+ getChannel() interface{}
+}
+
+type recvAction interface {
+ recvMatch(interface{}) bool
+}
+
+type sendAction interface {
+ send()
+}
+
+// isReady returns true if all the predecessors of an Event have occurred.
+func (e Event) isReady() bool {
+ for _, predecessor := range e.predecessors {
+ if !predecessor.occurred {
+ return false
+ }
+ }
+
+ return true
+}
+
+// A Recv action reads a value from a channel and uses reflect.DeepMatch to
+// compare it with an expected value.
+type Recv struct {
+ Channel interface{}
+ Expected interface{}
+}
+
+func (r Recv) getRecv() recvAction { return r }
+
+func (Recv) getSend() sendAction { return nil }
+
+func (r Recv) getChannel() interface{} { return r.Channel }
+
+func (r Recv) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelRecv)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return reflect.DeepEqual(c.value, r.Expected)
+}
+
+// A RecvMatch action reads a value from a channel and calls a function to
+// determine if the value matches.
+type RecvMatch struct {
+ Channel interface{}
+ Match func(interface{}) bool
+}
+
+func (r RecvMatch) getRecv() recvAction { return r }
+
+func (RecvMatch) getSend() sendAction { return nil }
+
+func (r RecvMatch) getChannel() interface{} { return r.Channel }
+
+func (r RecvMatch) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelRecv)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return r.Match(c.value)
+}
+
+// A Closed action matches if the given channel is closed. The closing is
+// treated as an event, not a state, thus Closed will only match once for a
+// given channel.
+type Closed struct {
+ Channel interface{}
+}
+
+func (r Closed) getRecv() recvAction { return r }
+
+func (Closed) getSend() sendAction { return nil }
+
+func (r Closed) getChannel() interface{} { return r.Channel }
+
+func (r Closed) recvMatch(chanEvent interface{}) bool {
+ c, ok := chanEvent.(channelClosed)
+ if !ok || c.channel != r.Channel {
+ return false
+ }
+
+ return true
+}
+
+// A Send action sends a value to a channel. The value must match the
+// type of the channel exactly unless the channel if of type chan interface{}.
+type Send struct {
+ Channel interface{}
+ Value interface{}
+}
+
+func (Send) getRecv() recvAction { return nil }
+
+func (s Send) getSend() sendAction { return s }
+
+func (s Send) getChannel() interface{} { return s.Channel }
+
+type empty struct {
+ x interface{}
+}
+
+func newEmptyInterface(e empty) reflect.Value {
+ return reflect.NewValue(e).Field(0)
+}
+
+func (s Send) send() {
+ // With reflect.ChanValue.Send, we must match the types exactly. So, if
+ // s.Channel is a chan interface{} we convert s.Value to an interface{}
+ // first.
+ c := reflect.NewValue(s.Channel)
+ var v reflect.Value
+ if iface := c.Type().Elem(); iface.Kind() == reflect.Interface && iface.NumMethod() == 0 {
+ v = newEmptyInterface(empty{s.Value})
+ } else {
+ v = reflect.NewValue(s.Value)
+ }
+ c.Send(v)
+}
+
+// A Close action closes the given channel.
+type Close struct {
+ Channel interface{}
+}
+
+func (Close) getRecv() recvAction { return nil }
+
+func (s Close) getSend() sendAction { return s }
+
+func (s Close) getChannel() interface{} { return s.Channel }
+
+func (s Close) send() { reflect.NewValue(s.Channel).Close() }
+
+// A ReceivedUnexpected error results if no active Events match a value
+// received from a channel.
+type ReceivedUnexpected struct {
+ Value interface{}
+ ready []*Event
+}
+
+func (r ReceivedUnexpected) String() string {
+ names := make([]string, len(r.ready))
+ for i, v := range r.ready {
+ names[i] = v.name
+ }
+ return fmt.Sprintf("received unexpected value on one of the channels: %#v. Runnable events: %s", r.Value, strings.Join(names, ", "))
+}
+
+// A SetupError results if there is a error with the configuration of a set of
+// Events.
+type SetupError string
+
+func (s SetupError) String() string { return string(s) }
+
+func NewEvent(name string, predecessors []*Event, action action) *Event {
+ e := &Event{name, false, predecessors, action}
+ return e
+}
+
+// Given a set of Events, Perform repeatedly iterates over the set and finds the
+// subset of ready Events (that is, all of their predecessors have
+// occurred). From that subset, it pseudo-randomly selects an Event to perform.
+// If the Event is a send event, the send occurs and Perform recalculates the ready
+// set. If the event is a receive event, Perform waits for a value from any of the
+// channels that are contained in any of the events. That value is then matched
+// against the ready events. The first event that matches is considered to
+// have occurred and Perform recalculates the ready set.
+//
+// Perform continues this until all Events have occurred.
+//
+// Note that uncollected goroutines may still be reading from any of the
+// channels read from after Perform returns.
+//
+// For example, consider the problem of testing a function that reads values on
+// one channel and echos them to two output channels. To test this we would
+// create three events: a send event and two receive events. Each of the
+// receive events must list the send event as a predecessor but there is no
+// ordering between the receive events.
+//
+// send := NewEvent("send", nil, Send{c, 1})
+// recv1 := NewEvent("recv 1", []*Event{send}, Recv{c, 1})
+// recv2 := NewEvent("recv 2", []*Event{send}, Recv{c, 1})
+// Perform(0, []*Event{send, recv1, recv2})
+//
+// At first, only the send event would be in the ready set and thus Perform will
+// send a value to the input channel. Now the two receive events are ready and
+// Perform will match each of them against the values read from the output channels.
+//
+// It would be invalid to list one of the receive events as a predecessor of
+// the other. At each receive step, all the receive channels are considered,
+// thus Perform may see a value from a channel that is not in the current ready
+// set and fail.
+func Perform(seed int64, events []*Event) (err os.Error) {
+ r := rand.New(rand.NewSource(seed))
+
+ channels, err := getChannels(events)
+ if err != nil {
+ return
+ }
+ multiplex := make(chan interface{})
+ for _, channel := range channels {
+ go recvValues(multiplex, channel)
+ }
+
+Outer:
+ for {
+ ready, err := readyEvents(events)
+ if err != nil {
+ return err
+ }
+
+ if len(ready) == 0 {
+ // All events occurred.
+ break
+ }
+
+ event := ready[r.Intn(len(ready))]
+ if send := event.action.getSend(); send != nil {
+ send.send()
+ event.occurred = true
+ continue
+ }
+
+ v := <-multiplex
+ for _, event := range ready {
+ if recv := event.action.getRecv(); recv != nil && recv.recvMatch(v) {
+ event.occurred = true
+ continue Outer
+ }
+ }
+
+ return ReceivedUnexpected{v, ready}
+ }
+
+ return nil
+}
+
+// getChannels returns all the channels listed in any receive events.
+func getChannels(events []*Event) ([]interface{}, os.Error) {
+ channels := make([]interface{}, len(events))
+
+ j := 0
+ for _, event := range events {
+ if recv := event.action.getRecv(); recv == nil {
+ continue
+ }
+ c := event.action.getChannel()
+ if reflect.NewValue(c).Kind() != reflect.Chan {
+ return nil, SetupError("one of the channel values is not a channel")
+ }
+
+ duplicate := false
+ for _, other := range channels[0:j] {
+ if c == other {
+ duplicate = true
+ break
+ }
+ }
+
+ if !duplicate {
+ channels[j] = c
+ j++
+ }
+ }
+
+ return channels[0:j], nil
+}
+
+// recvValues is a multiplexing helper function. It reads values from the given
+// channel repeatedly, wrapping them up as either a channelRecv or
+// channelClosed structure, and forwards them to the multiplex channel.
+func recvValues(multiplex chan<- interface{}, channel interface{}) {
+ c := reflect.NewValue(channel)
+
+ for {
+ v, ok := c.Recv()
+ if !ok {
+ multiplex <- channelClosed{channel}
+ return
+ }
+
+ multiplex <- channelRecv{channel, v.Interface()}
+ }
+}
+
+type channelClosed struct {
+ channel interface{}
+}
+
+type channelRecv struct {
+ channel interface{}
+ value interface{}
+}
+
+// readyEvents returns the subset of events that are ready.
+func readyEvents(events []*Event) ([]*Event, os.Error) {
+ ready := make([]*Event, len(events))
+
+ j := 0
+ eventsWaiting := false
+ for _, event := range events {
+ if event.occurred {
+ continue
+ }
+
+ eventsWaiting = true
+ if event.isReady() {
+ ready[j] = event
+ j++
+ }
+ }
+
+ if j == 0 && eventsWaiting {
+ names := make([]string, len(events))
+ for _, event := range events {
+ if event.occurred {
+ continue
+ }
+ names[j] = event.name
+ }
+
+ return nil, SetupError("dependency cycle in events. These events are waiting to run but cannot: " + strings.Join(names, ", "))
+ }
+
+ return ready[0:j], nil
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ Data-driven templates for generating textual output such as
+ HTML.
+
+ Templates are executed by applying them to a data structure.
+ Annotations in the template refer to elements of the data
+ structure (typically a field of a struct or a key in a map)
+ to control execution and derive values to be displayed.
+ The template walks the structure as it executes and the
+ "cursor" @ represents the value at the current location
+ in the structure.
+
+ Data items may be values or pointers; the interface hides the
+ indirection.
+
+ In the following, 'field' is one of several things, according to the data.
+
+ - The name of a field of a struct (result = data.field),
+ - The value stored in a map under that key (result = data[field]), or
+ - The result of invoking a niladic single-valued method with that name
+ (result = data.field())
+
+ Major constructs ({} are the default delimiters for template actions;
+ [] are the notation in this comment for optional elements):
+
+ {# comment }
+
+ A one-line comment.
+
+ {.section field} XXX [ {.or} YYY ] {.end}
+
+ Set @ to the value of the field. It may be an explicit @
+ to stay at the same point in the data. If the field is nil
+ or empty, execute YYY; otherwise execute XXX.
+
+ {.repeated section field} XXX [ {.alternates with} ZZZ ] [ {.or} YYY ] {.end}
+
+ Like .section, but field must be an array or slice. XXX
+ is executed for each element. If the array is nil or empty,
+ YYY is executed instead. If the {.alternates with} marker
+ is present, ZZZ is executed between iterations of XXX.
+
+ {field}
+ {field1 field2 ...}
+ {field|formatter}
+ {field1 field2...|formatter}
+ {field|formatter1|formatter2}
+
+ Insert the value of the fields into the output. Each field is
+ first looked for in the cursor, as in .section and .repeated.
+ If it is not found, the search continues in outer sections
+ until the top level is reached.
+
+ If the field value is a pointer, leading asterisks indicate
+ that the value to be inserted should be evaluated through the
+ pointer. For example, if x.p is of type *int, {x.p} will
+ insert the value of the pointer but {*x.p} will insert the
+ value of the underlying integer. If the value is nil or not a
+ pointer, asterisks have no effect.
+
+ If a formatter is specified, it must be named in the formatter
+ map passed to the template set up routines or in the default
+ set ("html","str","") and is used to process the data for
+ output. The formatter function has signature
+ func(wr io.Writer, formatter string, data ...interface{})
+ where wr is the destination for output, data holds the field
+ values at the instantiation, and formatter is its name at
+ the invocation site. The default formatter just concatenates
+ the string representations of the fields.
+
+ Multiple formatters separated by the pipeline character | are
+ executed sequentially, with each formatter receiving the bytes
+ emitted by the one to its left.
+
+ The delimiter strings get their default value, "{" and "}", from
+ JSON-template. They may be set to any non-empty, space-free
+ string using the SetDelims method. Their value can be printed
+ in the output using {.meta-left} and {.meta-right}.
+*/
+package template
+
+import (
+ "bytes"
+ "container/vector"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "reflect"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// Errors returned during parsing and execution. Users may extract the information and reformat
+// if they desire.
+type Error struct {
+ Line int
+ Msg string
+}
+
+func (e *Error) String() string { return fmt.Sprintf("line %d: %s", e.Line, e.Msg) }
+
+// Most of the literals are aces.
+var lbrace = []byte{'{'}
+var rbrace = []byte{'}'}
+var space = []byte{' '}
+var tab = []byte{'\t'}
+
+// The various types of "tokens", which are plain text or (usually) brace-delimited descriptors
+const (
+ tokAlternates = iota
+ tokComment
+ tokEnd
+ tokLiteral
+ tokOr
+ tokRepeated
+ tokSection
+ tokText
+ tokVariable
+)
+
+// FormatterMap is the type describing the mapping from formatter
+// names to the functions that implement them.
+type FormatterMap map[string]func(io.Writer, string, ...interface{})
+
+// Built-in formatters.
+var builtins = FormatterMap{
+ "html": HTMLFormatter,
+ "str": StringFormatter,
+ "": StringFormatter,
+}
+
+// The parsed state of a template is a vector of xxxElement structs.
+// Sections have line numbers so errors can be reported better during execution.
+
+// Plain text.
+type textElement struct {
+ text []byte
+}
+
+// A literal such as .meta-left or .meta-right
+type literalElement struct {
+ text []byte
+}
+
+// A variable invocation to be evaluated
+type variableElement struct {
+ linenum int
+ word []string // The fields in the invocation.
+ fmts []string // Names of formatters to apply. len(fmts) > 0
+}
+
+// A .section block, possibly with a .or
+type sectionElement struct {
+ linenum int // of .section itself
+ field string // cursor field for this block
+ start int // first element
+ or int // first element of .or block
+ end int // one beyond last element
+}
+
+// A .repeated block, possibly with a .or and a .alternates
+type repeatedElement struct {
+ sectionElement // It has the same structure...
+ altstart int // ... except for alternates
+ altend int
+}
+
+// Template is the type that represents a template definition.
+// It is unchanged after parsing.
+type Template struct {
+ fmap FormatterMap // formatters for variables
+ // Used during parsing:
+ ldelim, rdelim []byte // delimiters; default {}
+ buf []byte // input text to process
+ p int // position in buf
+ linenum int // position in input
+ // Parsed results:
+ elems *vector.Vector
+}
+
+// Internal state for executing a Template. As we evaluate the struct,
+// the data item descends into the fields associated with sections, etc.
+// Parent is used to walk upwards to find variables higher in the tree.
+type state struct {
+ parent *state // parent in hierarchy
+ data reflect.Value // the driver data for this section etc.
+ wr io.Writer // where to send output
+ buf [2]bytes.Buffer // alternating buffers used when chaining formatters
+}
+
+func (parent *state) clone(data reflect.Value) *state {
+ return &state{parent: parent, data: data, wr: parent.wr}
+}
+
+// New creates a new template with the specified formatter map (which
+// may be nil) to define auxiliary functions for formatting variables.
+func New(fmap FormatterMap) *Template {
+ t := new(Template)
+ t.fmap = fmap
+ t.ldelim = lbrace
+ t.rdelim = rbrace
+ t.elems = new(vector.Vector)
+ return t
+}
+
+// Report error and stop executing. The line number must be provided explicitly.
+func (t *Template) execError(st *state, line int, err string, args ...interface{}) {
+ panic(&Error{line, fmt.Sprintf(err, args...)})
+}
+
+// Report error, panic to terminate parsing.
+// The line number comes from the template state.
+func (t *Template) parseError(err string, args ...interface{}) {
+ panic(&Error{t.linenum, fmt.Sprintf(err, args...)})
+}
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// -- Lexical analysis
+
+// Is c a white space character?
+func white(c uint8) bool { return c == ' ' || c == '\t' || c == '\r' || c == '\n' }
+
+// Safely, does s[n:n+len(t)] == t?
+func equal(s []byte, n int, t []byte) bool {
+ b := s[n:]
+ if len(t) > len(b) { // not enough space left for a match.
+ return false
+ }
+ for i, c := range t {
+ if c != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// nextItem returns the next item from the input buffer. If the returned
+// item is empty, we are at EOF. The item will be either a
+// delimited string or a non-empty string between delimited
+// strings. Tokens stop at (but include, if plain text) a newline.
+// Action tokens on a line by themselves drop any space on
+// either side, up to and including the newline.
+func (t *Template) nextItem() []byte {
+ startOfLine := t.p == 0 || t.buf[t.p-1] == '\n'
+ start := t.p
+ var i int
+ newline := func() {
+ t.linenum++
+ i++
+ }
+ // Leading white space up to but not including newline
+ for i = start; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' || !white(t.buf[i]) {
+ break
+ }
+ }
+ leadingSpace := i > start
+ // What's left is nothing, newline, delimited string, or plain text
+ switch {
+ case i == len(t.buf):
+ // EOF; nothing to do
+ case t.buf[i] == '\n':
+ newline()
+ case equal(t.buf, i, t.ldelim):
+ left := i // Start of left delimiter.
+ right := -1 // Will be (immediately after) right delimiter.
+ haveText := false // Delimiters contain text.
+ i += len(t.ldelim)
+ // Find the end of the action.
+ for ; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' {
+ break
+ }
+ if equal(t.buf, i, t.rdelim) {
+ i += len(t.rdelim)
+ right = i
+ break
+ }
+ haveText = true
+ }
+ if right < 0 {
+ t.parseError("unmatched opening delimiter")
+ return nil
+ }
+ // Is this a special action (starts with '.' or '#') and the only thing on the line?
+ if startOfLine && haveText {
+ firstChar := t.buf[left+len(t.ldelim)]
+ if firstChar == '.' || firstChar == '#' {
+ // It's special and the first thing on the line. Is it the last?
+ for j := right; j < len(t.buf) && white(t.buf[j]); j++ {
+ if t.buf[j] == '\n' {
+ // Yes it is. Drop the surrounding space and return the {.foo}
+ t.linenum++
+ t.p = j + 1
+ return t.buf[left:right]
+ }
+ }
+ }
+ }
+ // No it's not. If there's leading space, return that.
+ if leadingSpace {
+ // not trimming space: return leading white space if there is some.
+ t.p = left
+ return t.buf[start:left]
+ }
+ // Return the word, leave the trailing space.
+ start = left
+ break
+ default:
+ for ; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' {
+ newline()
+ break
+ }
+ if equal(t.buf, i, t.ldelim) {
+ break
+ }
+ }
+ }
+ item := t.buf[start:i]
+ t.p = i
+ return item
+}
+
+// Turn a byte array into a white-space-split array of strings.
+func words(buf []byte) []string {
+ s := make([]string, 0, 5)
+ p := 0 // position in buf
+ // one word per loop
+ for i := 0; ; i++ {
+ // skip white space
+ for ; p < len(buf) && white(buf[p]); p++ {
+ }
+ // grab word
+ start := p
+ for ; p < len(buf) && !white(buf[p]); p++ {
+ }
+ if start == p { // no text left
+ break
+ }
+ s = append(s, string(buf[start:p]))
+ }
+ return s
+}
+
+// Analyze an item and return its token type and, if it's an action item, an array of
+// its constituent words.
+func (t *Template) analyze(item []byte) (tok int, w []string) {
+ // item is known to be non-empty
+ if !equal(item, 0, t.ldelim) { // doesn't start with left delimiter
+ tok = tokText
+ return
+ }
+ if !equal(item, len(item)-len(t.rdelim), t.rdelim) { // doesn't end with right delimiter
+ t.parseError("internal error: unmatched opening delimiter") // lexing should prevent this
+ return
+ }
+ if len(item) <= len(t.ldelim)+len(t.rdelim) { // no contents
+ t.parseError("empty directive")
+ return
+ }
+ // Comment
+ if item[len(t.ldelim)] == '#' {
+ tok = tokComment
+ return
+ }
+ // Split into words
+ w = words(item[len(t.ldelim) : len(item)-len(t.rdelim)]) // drop final delimiter
+ if len(w) == 0 {
+ t.parseError("empty directive")
+ return
+ }
+ if len(w) > 0 && w[0][0] != '.' {
+ tok = tokVariable
+ return
+ }
+ switch w[0] {
+ case ".meta-left", ".meta-right", ".space", ".tab":
+ tok = tokLiteral
+ return
+ case ".or":
+ tok = tokOr
+ return
+ case ".end":
+ tok = tokEnd
+ return
+ case ".section":
+ if len(w) != 2 {
+ t.parseError("incorrect fields for .section: %s", item)
+ return
+ }
+ tok = tokSection
+ return
+ case ".repeated":
+ if len(w) != 3 || w[1] != "section" {
+ t.parseError("incorrect fields for .repeated: %s", item)
+ return
+ }
+ tok = tokRepeated
+ return
+ case ".alternates":
+ if len(w) != 2 || w[1] != "with" {
+ t.parseError("incorrect fields for .alternates: %s", item)
+ return
+ }
+ tok = tokAlternates
+ return
+ }
+ t.parseError("bad directive: %s", item)
+ return
+}
+
+// formatter returns the Formatter with the given name in the Template, or nil if none exists.
+func (t *Template) formatter(name string) func(io.Writer, string, ...interface{}) {
+ if t.fmap != nil {
+ if fn := t.fmap[name]; fn != nil {
+ return fn
+ }
+ }
+ return builtins[name]
+}
+
+// -- Parsing
+
+// Allocate a new variable-evaluation element.
+func (t *Template) newVariable(words []string) *variableElement {
+ // After the final space-separated argument, formatters may be specified separated
+ // by pipe symbols, for example: {a b c|d|e}
+
+ // Until we learn otherwise, formatters contains a single name: "", the default formatter.
+ formatters := []string{""}
+ lastWord := words[len(words)-1]
+ bar := strings.IndexRune(lastWord, '|')
+ if bar >= 0 {
+ words[len(words)-1] = lastWord[0:bar]
+ formatters = strings.Split(lastWord[bar+1:], "|", -1)
+ }
+
+ // We could remember the function address here and avoid the lookup later,
+ // but it's more dynamic to let the user change the map contents underfoot.
+ // We do require the name to be present, though.
+
+ // Is it in user-supplied map?
+ for _, f := range formatters {
+ if t.formatter(f) == nil {
+ t.parseError("unknown formatter: %q", f)
+ }
+ }
+ return &variableElement{t.linenum, words, formatters}
+}
+
+// Grab the next item. If it's simple, just append it to the template.
+// Otherwise return its details.
+func (t *Template) parseSimple(item []byte) (done bool, tok int, w []string) {
+ tok, w = t.analyze(item)
+ done = true // assume for simplicity
+ switch tok {
+ case tokComment:
+ return
+ case tokText:
+ t.elems.Push(&textElement{item})
+ return
+ case tokLiteral:
+ switch w[0] {
+ case ".meta-left":
+ t.elems.Push(&literalElement{t.ldelim})
+ case ".meta-right":
+ t.elems.Push(&literalElement{t.rdelim})
+ case ".space":
+ t.elems.Push(&literalElement{space})
+ case ".tab":
+ t.elems.Push(&literalElement{tab})
+ default:
+ t.parseError("internal error: unknown literal: %s", w[0])
+ }
+ return
+ case tokVariable:
+ t.elems.Push(t.newVariable(w))
+ return
+ }
+ return false, tok, w
+}
+
+// parseRepeated and parseSection are mutually recursive
+
+func (t *Template) parseRepeated(words []string) *repeatedElement {
+ r := new(repeatedElement)
+ t.elems.Push(r)
+ r.linenum = t.linenum
+ r.field = words[2]
+ // Scan section, collecting true and false (.or) blocks.
+ r.start = t.elems.Len()
+ r.or = -1
+ r.altstart = -1
+ r.altend = -1
+Loop:
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ t.parseError("missing .end for .repeated section")
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokEnd:
+ break Loop
+ case tokOr:
+ if r.or >= 0 {
+ t.parseError("extra .or in .repeated section")
+ break Loop
+ }
+ r.altend = t.elems.Len()
+ r.or = t.elems.Len()
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ case tokAlternates:
+ if r.altstart >= 0 {
+ t.parseError("extra .alternates in .repeated section")
+ break Loop
+ }
+ if r.or >= 0 {
+ t.parseError(".alternates inside .or block in .repeated section")
+ break Loop
+ }
+ r.altstart = t.elems.Len()
+ default:
+ t.parseError("internal error: unknown repeated section item: %s", item)
+ break Loop
+ }
+ }
+ if r.altend < 0 {
+ r.altend = t.elems.Len()
+ }
+ r.end = t.elems.Len()
+ return r
+}
+
+func (t *Template) parseSection(words []string) *sectionElement {
+ s := new(sectionElement)
+ t.elems.Push(s)
+ s.linenum = t.linenum
+ s.field = words[1]
+ // Scan section, collecting true and false (.or) blocks.
+ s.start = t.elems.Len()
+ s.or = -1
+Loop:
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ t.parseError("missing .end for .section")
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokEnd:
+ break Loop
+ case tokOr:
+ if s.or >= 0 {
+ t.parseError("extra .or in .section")
+ break Loop
+ }
+ s.or = t.elems.Len()
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ case tokAlternates:
+ t.parseError(".alternates not in .repeated")
+ default:
+ t.parseError("internal error: unknown section item: %s", item)
+ }
+ }
+ s.end = t.elems.Len()
+ return s
+}
+
+func (t *Template) parse() {
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokOr, tokEnd, tokAlternates:
+ t.parseError("unexpected %s", w[0])
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ default:
+ t.parseError("internal error: bad directive in parse: %s", item)
+ }
+ }
+}
+
+// -- Execution
+
+// Evaluate interfaces and pointers looking for a value that can look up the name, via a
+// struct field, method, or map key, and return the result of the lookup.
+func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value {
+ for v != nil {
+ typ := v.Type()
+ if n := v.Type().NumMethod(); n > 0 {
+ for i := 0; i < n; i++ {
+ m := typ.Method(i)
+ mtyp := m.Type
+ if m.Name == name && mtyp.NumIn() == 1 && mtyp.NumOut() == 1 {
+ if !isExported(name) {
+ t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type())
+ }
+ return v.Method(i).Call(nil)[0]
+ }
+ }
+ }
+ switch av := v.(type) {
+ case *reflect.PtrValue:
+ v = av.Elem()
+ case *reflect.InterfaceValue:
+ v = av.Elem()
+ case *reflect.StructValue:
+ if !isExported(name) {
+ t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type())
+ }
+ return av.FieldByName(name)
+ case *reflect.MapValue:
+ if v := av.Elem(reflect.NewValue(name)); v != nil {
+ return v
+ }
+ return reflect.MakeZero(typ.(*reflect.MapType).Elem())
+ default:
+ return nil
+ }
+ }
+ return v
+}
+
+// indirectPtr returns the item numLevels levels of indirection below the value.
+// It is forgiving: if the value is not a pointer, it returns it rather than giving
+// an error. If the pointer is nil, it is returned as is.
+func indirectPtr(v reflect.Value, numLevels int) reflect.Value {
+ for i := numLevels; v != nil && i > 0; i++ {
+ if p, ok := v.(*reflect.PtrValue); ok {
+ if p.IsNil() {
+ return v
+ }
+ v = p.Elem()
+ } else {
+ break
+ }
+ }
+ return v
+}
+
+// Walk v through pointers and interfaces, extracting the elements within.
+func indirect(v reflect.Value) reflect.Value {
+loop:
+ for v != nil {
+ switch av := v.(type) {
+ case *reflect.PtrValue:
+ v = av.Elem()
+ case *reflect.InterfaceValue:
+ v = av.Elem()
+ default:
+ break loop
+ }
+ }
+ return v
+}
+
+// If the data for this template is a struct, find the named variable.
+// Names of the form a.b.c are walked down the data tree.
+// The special name "@" (the "cursor") denotes the current data.
+// The value coming in (st.data) might need indirecting to reach
+// a struct while the return value is not indirected - that is,
+// it represents the actual named field. Leading stars indicate
+// levels of indirection to be applied to the value.
+func (t *Template) findVar(st *state, s string) reflect.Value {
+ data := st.data
+ flattenedName := strings.TrimLeft(s, "*")
+ numStars := len(s) - len(flattenedName)
+ s = flattenedName
+ if s == "@" {
+ return indirectPtr(data, numStars)
+ }
+ for _, elem := range strings.Split(s, ".", -1) {
+ // Look up field; data must be a struct or map.
+ data = t.lookup(st, data, elem)
+ if data == nil {
+ return nil
+ }
+ }
+ return indirectPtr(data, numStars)
+}
+
+// Is there no data to look at?
+func empty(v reflect.Value) bool {
+ v = indirect(v)
+ if v == nil {
+ return true
+ }
+ switch v := v.(type) {
+ case *reflect.BoolValue:
+ return v.Get() == false
+ case *reflect.StringValue:
+ return v.Get() == ""
+ case *reflect.StructValue:
+ return false
+ case *reflect.MapValue:
+ return false
+ case *reflect.ArrayValue:
+ return v.Len() == 0
+ case *reflect.SliceValue:
+ return v.Len() == 0
+ }
+ return false
+}
+
+// Look up a variable or method, up through the parent if necessary.
+func (t *Template) varValue(name string, st *state) reflect.Value {
+ field := t.findVar(st, name)
+ if field == nil {
+ if st.parent == nil {
+ t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type())
+ }
+ return t.varValue(name, st.parent)
+ }
+ return field
+}
+
+func (t *Template) format(wr io.Writer, fmt string, val []interface{}, v *variableElement, st *state) {
+ fn := t.formatter(fmt)
+ if fn == nil {
+ t.execError(st, v.linenum, "missing formatter %s for variable %s", fmt, v.word[0])
+ }
+ fn(wr, fmt, val...)
+}
+
+// Evaluate a variable, looking up through the parent if necessary.
+// If it has a formatter attached ({var|formatter}) run that too.
+func (t *Template) writeVariable(v *variableElement, st *state) {
+ // Turn the words of the invocation into values.
+ val := make([]interface{}, len(v.word))
+ for i, word := range v.word {
+ val[i] = t.varValue(word, st).Interface()
+ }
+
+ for i, fmt := range v.fmts[:len(v.fmts)-1] {
+ b := &st.buf[i&1]
+ b.Reset()
+ t.format(b, fmt, val, v, st)
+ val = val[0:1]
+ val[0] = b.Bytes()
+ }
+ t.format(st.wr, v.fmts[len(v.fmts)-1], val, v, st)
+}
+
+// Execute element i. Return next index to execute.
+func (t *Template) executeElement(i int, st *state) int {
+ switch elem := t.elems.At(i).(type) {
+ case *textElement:
+ st.wr.Write(elem.text)
+ return i + 1
+ case *literalElement:
+ st.wr.Write(elem.text)
+ return i + 1
+ case *variableElement:
+ t.writeVariable(elem, st)
+ return i + 1
+ case *sectionElement:
+ t.executeSection(elem, st)
+ return elem.end
+ case *repeatedElement:
+ t.executeRepeated(elem, st)
+ return elem.end
+ }
+ e := t.elems.At(i)
+ t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e)
+ return 0
+}
+
+// Execute the template.
+func (t *Template) execute(start, end int, st *state) {
+ for i := start; i < end; {
+ i = t.executeElement(i, st)
+ }
+}
+
+// Execute a .section
+func (t *Template) executeSection(s *sectionElement, st *state) {
+ // Find driver data for this section. It must be in the current struct.
+ field := t.varValue(s.field, st)
+ if field == nil {
+ t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type())
+ }
+ st = st.clone(field)
+ start, end := s.start, s.or
+ if !empty(field) {
+ // Execute the normal block.
+ if end < 0 {
+ end = s.end
+ }
+ } else {
+ // Execute the .or block. If it's missing, do nothing.
+ start, end = s.or, s.end
+ if start < 0 {
+ return
+ }
+ }
+ for i := start; i < end; {
+ i = t.executeElement(i, st)
+ }
+}
+
+// Return the result of calling the Iter method on v, or nil.
+func iter(v reflect.Value) *reflect.ChanValue {
+ for j := 0; j < v.Type().NumMethod(); j++ {
+ mth := v.Type().Method(j)
+ fv := v.Method(j)
+ ft := fv.Type().(*reflect.FuncType)
+ // TODO(rsc): NumIn() should return 0 here, because ft is from a curried FuncValue.
+ if mth.Name != "Iter" || ft.NumIn() != 1 || ft.NumOut() != 1 {
+ continue
+ }
+ ct, ok := ft.Out(0).(*reflect.ChanType)
+ if !ok || ct.Dir()&reflect.RecvDir == 0 {
+ continue
+ }
+ return fv.Call(nil)[0].(*reflect.ChanValue)
+ }
+ return nil
+}
+
+// Execute a .repeated section
+func (t *Template) executeRepeated(r *repeatedElement, st *state) {
+ // Find driver data for this section. It must be in the current struct.
+ field := t.varValue(r.field, st)
+ if field == nil {
+ t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type())
+ }
+ field = indirect(field)
+
+ start, end := r.start, r.or
+ if end < 0 {
+ end = r.end
+ }
+ if r.altstart >= 0 {
+ end = r.altstart
+ }
+ first := true
+
+ // Code common to all the loops.
+ loopBody := func(newst *state) {
+ // .alternates between elements
+ if !first && r.altstart >= 0 {
+ for i := r.altstart; i < r.altend; {
+ i = t.executeElement(i, newst)
+ }
+ }
+ first = false
+ for i := start; i < end; {
+ i = t.executeElement(i, newst)
+ }
+ }
+
+ if array, ok := field.(reflect.ArrayOrSliceValue); ok {
+ for j := 0; j < array.Len(); j++ {
+ loopBody(st.clone(array.Elem(j)))
+ }
+ } else if m, ok := field.(*reflect.MapValue); ok {
+ for _, key := range m.Keys() {
+ loopBody(st.clone(m.Elem(key)))
+ }
+ } else if ch := iter(field); ch != nil {
+ for {
+ e, ok := ch.Recv()
+ if !ok {
+ break
+ }
+ loopBody(st.clone(e))
+ }
+ } else {
+ t.execError(st, r.linenum, ".repeated: cannot repeat %s (type %s)",
+ r.field, field.Type())
+ }
+
+ if first {
+ // Empty. Execute the .or block, once. If it's missing, do nothing.
+ start, end := r.or, r.end
+ if start >= 0 {
+ newst := st.clone(field)
+ for i := start; i < end; {
+ i = t.executeElement(i, newst)
+ }
+ }
+ return
+ }
+}
+
+// A valid delimiter must contain no white space and be non-empty.
+func validDelim(d []byte) bool {
+ if len(d) == 0 {
+ return false
+ }
+ for _, c := range d {
+ if white(c) {
+ return false
+ }
+ }
+ return true
+}
+
+// checkError is a deferred function to turn a panic with type *Error into a plain error return.
+// Other panics are unexpected and so are re-enabled.
+func checkError(error *os.Error) {
+ if v := recover(); v != nil {
+ if e, ok := v.(*Error); ok {
+ *error = e
+ } else {
+ // runtime errors should crash
+ panic(v)
+ }
+ }
+}
+
+// -- Public interface
+
+// Parse initializes a Template by parsing its definition. The string
+// s contains the template text. If any errors occur, Parse returns
+// the error.
+func (t *Template) Parse(s string) (err os.Error) {
+ if t.elems == nil {
+ return &Error{1, "template not allocated with New"}
+ }
+ if !validDelim(t.ldelim) || !validDelim(t.rdelim) {
+ return &Error{1, fmt.Sprintf("bad delimiter strings %q %q", t.ldelim, t.rdelim)}
+ }
+ defer checkError(&err)
+ t.buf = []byte(s)
+ t.p = 0
+ t.linenum = 1
+ t.parse()
+ return nil
+}
+
+// ParseFile is like Parse but reads the template definition from the
+// named file.
+func (t *Template) ParseFile(filename string) (err os.Error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return err
+ }
+ return t.Parse(string(b))
+}
+
+// Execute applies a parsed template to the specified data object,
+// generating output to wr.
+func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) {
+ // Extract the driver data.
+ val := reflect.NewValue(data)
+ defer checkError(&err)
+ t.p = 0
+ t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr})
+ return nil
+}
+
+// SetDelims sets the left and right delimiters for operations in the
+// template. They are validated during parsing. They could be
+// validated here but it's better to keep the routine simple. The
+// delimiters are very rarely invalid and Parse has the necessary
+// error-handling interface already.
+func (t *Template) SetDelims(left, right string) {
+ t.ldelim = []byte(left)
+ t.rdelim = []byte(right)
+}
+
+// Parse creates a Template with default parameters (such as {} for
+// metacharacters). The string s contains the template text while
+// the formatter map fmap, which may be nil, defines auxiliary functions
+// for formatting variables. The template is returned. If any errors
+// occur, err will be non-nil.
+func Parse(s string, fmap FormatterMap) (t *Template, err os.Error) {
+ t = New(fmap)
+ err = t.Parse(s)
+ if err != nil {
+ t = nil
+ }
+ return
+}
+
+// ParseFile is a wrapper function that creates a Template with default
+// parameters (such as {} for metacharacters). The filename identifies
+// a file containing the template text, while the formatter map fmap, which
+// may be nil, defines auxiliary functions for formatting variables.
+// The template is returned. If any errors occur, err will be non-nil.
+func ParseFile(filename string, fmap FormatterMap) (t *Template, err os.Error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ return Parse(string(b), fmap)
+}
+
+// MustParse is like Parse but panics if the template cannot be parsed.
+func MustParse(s string, fmap FormatterMap) *Template {
+ t, err := Parse(s, fmap)
+ if err != nil {
+ panic("template.MustParse error: " + err.String())
+ }
+ return t
+}
+
+// MustParseFile is like ParseFile but panics if the file cannot be read
+// or the template cannot be parsed.
+func MustParseFile(filename string, fmap FormatterMap) *Template {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ panic("template.MustParseFile error: " + err.String())
+ }
+ return MustParse(string(b), fmap)
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ Data-driven templates for generating textual output such as
+ HTML.
+
+ Templates are executed by applying them to a data structure.
+ Annotations in the template refer to elements of the data
+ structure (typically a field of a struct or a key in a map)
+ to control execution and derive values to be displayed.
+ The template walks the structure as it executes and the
+ "cursor" @ represents the value at the current location
+ in the structure.
+
+ Data items may be values or pointers; the interface hides the
+ indirection.
+
+ In the following, 'field' is one of several things, according to the data.
+
+ - The name of a field of a struct (result = data.field),
+ - The value stored in a map under that key (result = data[field]), or
+ - The result of invoking a niladic single-valued method with that name
+ (result = data.field())
+
+ Major constructs ({} are the default delimiters for template actions;
+ [] are the notation in this comment for optional elements):
+
+ {# comment }
+
+ A one-line comment.
+
+ {.section field} XXX [ {.or} YYY ] {.end}
+
+ Set @ to the value of the field. It may be an explicit @
+ to stay at the same point in the data. If the field is nil
+ or empty, execute YYY; otherwise execute XXX.
+
+ {.repeated section field} XXX [ {.alternates with} ZZZ ] [ {.or} YYY ] {.end}
+
+ Like .section, but field must be an array or slice. XXX
+ is executed for each element. If the array is nil or empty,
+ YYY is executed instead. If the {.alternates with} marker
+ is present, ZZZ is executed between iterations of XXX.
+
+ {field}
+ {field1 field2 ...}
+ {field|formatter}
+ {field1 field2...|formatter}
+ {field|formatter1|formatter2}
+
+ Insert the value of the fields into the output. Each field is
+ first looked for in the cursor, as in .section and .repeated.
+ If it is not found, the search continues in outer sections
+ until the top level is reached.
+
+ If the field value is a pointer, leading asterisks indicate
+ that the value to be inserted should be evaluated through the
+ pointer. For example, if x.p is of type *int, {x.p} will
+ insert the value of the pointer but {*x.p} will insert the
+ value of the underlying integer. If the value is nil or not a
+ pointer, asterisks have no effect.
+
+ If a formatter is specified, it must be named in the formatter
+ map passed to the template set up routines or in the default
+ set ("html","str","") and is used to process the data for
+ output. The formatter function has signature
+ func(wr io.Writer, formatter string, data ...interface{})
+ where wr is the destination for output, data holds the field
+ values at the instantiation, and formatter is its name at
+ the invocation site. The default formatter just concatenates
+ the string representations of the fields.
+
+ Multiple formatters separated by the pipeline character | are
+ executed sequentially, with each formatter receiving the bytes
+ emitted by the one to its left.
+
+ The delimiter strings get their default value, "{" and "}", from
+ JSON-template. They may be set to any non-empty, space-free
+ string using the SetDelims method. Their value can be printed
+ in the output using {.meta-left} and {.meta-right}.
+*/
+package template
+
+import (
+ "bytes"
+ "container/vector"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "reflect"
+ "strings"
+ "unicode"
+ "utf8"
+)
+
+// Errors returned during parsing and execution. Users may extract the information and reformat
+// if they desire.
+type Error struct {
+ Line int
+ Msg string
+}
+
+func (e *Error) String() string { return fmt.Sprintf("line %d: %s", e.Line, e.Msg) }
+
+// Most of the literals are aces.
+var lbrace = []byte{'{'}
+var rbrace = []byte{'}'}
+var space = []byte{' '}
+var tab = []byte{'\t'}
+
+// The various types of "tokens", which are plain text or (usually) brace-delimited descriptors
+const (
+ tokAlternates = iota
+ tokComment
+ tokEnd
+ tokLiteral
+ tokOr
+ tokRepeated
+ tokSection
+ tokText
+ tokVariable
+)
+
+// FormatterMap is the type describing the mapping from formatter
+// names to the functions that implement them.
+type FormatterMap map[string]func(io.Writer, string, ...interface{})
+
+// Built-in formatters.
+var builtins = FormatterMap{
+ "html": HTMLFormatter,
+ "str": StringFormatter,
+ "": StringFormatter,
+}
+
+// The parsed state of a template is a vector of xxxElement structs.
+// Sections have line numbers so errors can be reported better during execution.
+
+// Plain text.
+type textElement struct {
+ text []byte
+}
+
+// A literal such as .meta-left or .meta-right
+type literalElement struct {
+ text []byte
+}
+
+// A variable invocation to be evaluated
+type variableElement struct {
+ linenum int
+ word []string // The fields in the invocation.
+ fmts []string // Names of formatters to apply. len(fmts) > 0
+}
+
+// A .section block, possibly with a .or
+type sectionElement struct {
+ linenum int // of .section itself
+ field string // cursor field for this block
+ start int // first element
+ or int // first element of .or block
+ end int // one beyond last element
+}
+
+// A .repeated block, possibly with a .or and a .alternates
+type repeatedElement struct {
+ sectionElement // It has the same structure...
+ altstart int // ... except for alternates
+ altend int
+}
+
+// Template is the type that represents a template definition.
+// It is unchanged after parsing.
+type Template struct {
+ fmap FormatterMap // formatters for variables
+ // Used during parsing:
+ ldelim, rdelim []byte // delimiters; default {}
+ buf []byte // input text to process
+ p int // position in buf
+ linenum int // position in input
+ // Parsed results:
+ elems *vector.Vector
+}
+
+// Internal state for executing a Template. As we evaluate the struct,
+// the data item descends into the fields associated with sections, etc.
+// Parent is used to walk upwards to find variables higher in the tree.
+type state struct {
+ parent *state // parent in hierarchy
+ data reflect.Value // the driver data for this section etc.
+ wr io.Writer // where to send output
+ buf [2]bytes.Buffer // alternating buffers used when chaining formatters
+}
+
+func (parent *state) clone(data reflect.Value) *state {
+ return &state{parent: parent, data: data, wr: parent.wr}
+}
+
+// New creates a new template with the specified formatter map (which
+// may be nil) to define auxiliary functions for formatting variables.
+func New(fmap FormatterMap) *Template {
+ t := new(Template)
+ t.fmap = fmap
+ t.ldelim = lbrace
+ t.rdelim = rbrace
+ t.elems = new(vector.Vector)
+ return t
+}
+
+// Report error and stop executing. The line number must be provided explicitly.
+func (t *Template) execError(st *state, line int, err string, args ...interface{}) {
+ panic(&Error{line, fmt.Sprintf(err, args...)})
+}
+
+// Report error, panic to terminate parsing.
+// The line number comes from the template state.
+func (t *Template) parseError(err string, args ...interface{}) {
+ panic(&Error{t.linenum, fmt.Sprintf(err, args...)})
+}
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// -- Lexical analysis
+
+// Is c a white space character?
+func white(c uint8) bool { return c == ' ' || c == '\t' || c == '\r' || c == '\n' }
+
+// Safely, does s[n:n+len(t)] == t?
+func equal(s []byte, n int, t []byte) bool {
+ b := s[n:]
+ if len(t) > len(b) { // not enough space left for a match.
+ return false
+ }
+ for i, c := range t {
+ if c != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// nextItem returns the next item from the input buffer. If the returned
+// item is empty, we are at EOF. The item will be either a
+// delimited string or a non-empty string between delimited
+// strings. Tokens stop at (but include, if plain text) a newline.
+// Action tokens on a line by themselves drop any space on
+// either side, up to and including the newline.
+func (t *Template) nextItem() []byte {
+ startOfLine := t.p == 0 || t.buf[t.p-1] == '\n'
+ start := t.p
+ var i int
+ newline := func() {
+ t.linenum++
+ i++
+ }
+ // Leading white space up to but not including newline
+ for i = start; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' || !white(t.buf[i]) {
+ break
+ }
+ }
+ leadingSpace := i > start
+ // What's left is nothing, newline, delimited string, or plain text
+ switch {
+ case i == len(t.buf):
+ // EOF; nothing to do
+ case t.buf[i] == '\n':
+ newline()
+ case equal(t.buf, i, t.ldelim):
+ left := i // Start of left delimiter.
+ right := -1 // Will be (immediately after) right delimiter.
+ haveText := false // Delimiters contain text.
+ i += len(t.ldelim)
+ // Find the end of the action.
+ for ; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' {
+ break
+ }
+ if equal(t.buf, i, t.rdelim) {
+ i += len(t.rdelim)
+ right = i
+ break
+ }
+ haveText = true
+ }
+ if right < 0 {
+ t.parseError("unmatched opening delimiter")
+ return nil
+ }
+ // Is this a special action (starts with '.' or '#') and the only thing on the line?
+ if startOfLine && haveText {
+ firstChar := t.buf[left+len(t.ldelim)]
+ if firstChar == '.' || firstChar == '#' {
+ // It's special and the first thing on the line. Is it the last?
+ for j := right; j < len(t.buf) && white(t.buf[j]); j++ {
+ if t.buf[j] == '\n' {
+ // Yes it is. Drop the surrounding space and return the {.foo}
+ t.linenum++
+ t.p = j + 1
+ return t.buf[left:right]
+ }
+ }
+ }
+ }
+ // No it's not. If there's leading space, return that.
+ if leadingSpace {
+ // not trimming space: return leading white space if there is some.
+ t.p = left
+ return t.buf[start:left]
+ }
+ // Return the word, leave the trailing space.
+ start = left
+ break
+ default:
+ for ; i < len(t.buf); i++ {
+ if t.buf[i] == '\n' {
+ newline()
+ break
+ }
+ if equal(t.buf, i, t.ldelim) {
+ break
+ }
+ }
+ }
+ item := t.buf[start:i]
+ t.p = i
+ return item
+}
+
+// Turn a byte array into a white-space-split array of strings.
+func words(buf []byte) []string {
+ s := make([]string, 0, 5)
+ p := 0 // position in buf
+ // one word per loop
+ for i := 0; ; i++ {
+ // skip white space
+ for ; p < len(buf) && white(buf[p]); p++ {
+ }
+ // grab word
+ start := p
+ for ; p < len(buf) && !white(buf[p]); p++ {
+ }
+ if start == p { // no text left
+ break
+ }
+ s = append(s, string(buf[start:p]))
+ }
+ return s
+}
+
+// Analyze an item and return its token type and, if it's an action item, an array of
+// its constituent words.
+func (t *Template) analyze(item []byte) (tok int, w []string) {
+ // item is known to be non-empty
+ if !equal(item, 0, t.ldelim) { // doesn't start with left delimiter
+ tok = tokText
+ return
+ }
+ if !equal(item, len(item)-len(t.rdelim), t.rdelim) { // doesn't end with right delimiter
+ t.parseError("internal error: unmatched opening delimiter") // lexing should prevent this
+ return
+ }
+ if len(item) <= len(t.ldelim)+len(t.rdelim) { // no contents
+ t.parseError("empty directive")
+ return
+ }
+ // Comment
+ if item[len(t.ldelim)] == '#' {
+ tok = tokComment
+ return
+ }
+ // Split into words
+ w = words(item[len(t.ldelim) : len(item)-len(t.rdelim)]) // drop final delimiter
+ if len(w) == 0 {
+ t.parseError("empty directive")
+ return
+ }
+ if len(w) > 0 && w[0][0] != '.' {
+ tok = tokVariable
+ return
+ }
+ switch w[0] {
+ case ".meta-left", ".meta-right", ".space", ".tab":
+ tok = tokLiteral
+ return
+ case ".or":
+ tok = tokOr
+ return
+ case ".end":
+ tok = tokEnd
+ return
+ case ".section":
+ if len(w) != 2 {
+ t.parseError("incorrect fields for .section: %s", item)
+ return
+ }
+ tok = tokSection
+ return
+ case ".repeated":
+ if len(w) != 3 || w[1] != "section" {
+ t.parseError("incorrect fields for .repeated: %s", item)
+ return
+ }
+ tok = tokRepeated
+ return
+ case ".alternates":
+ if len(w) != 2 || w[1] != "with" {
+ t.parseError("incorrect fields for .alternates: %s", item)
+ return
+ }
+ tok = tokAlternates
+ return
+ }
+ t.parseError("bad directive: %s", item)
+ return
+}
+
+// formatter returns the Formatter with the given name in the Template, or nil if none exists.
+func (t *Template) formatter(name string) func(io.Writer, string, ...interface{}) {
+ if t.fmap != nil {
+ if fn := t.fmap[name]; fn != nil {
+ return fn
+ }
+ }
+ return builtins[name]
+}
+
+// -- Parsing
+
+// Allocate a new variable-evaluation element.
+func (t *Template) newVariable(words []string) *variableElement {
+ // After the final space-separated argument, formatters may be specified separated
+ // by pipe symbols, for example: {a b c|d|e}
+
+ // Until we learn otherwise, formatters contains a single name: "", the default formatter.
+ formatters := []string{""}
+ lastWord := words[len(words)-1]
+ bar := strings.IndexRune(lastWord, '|')
+ if bar >= 0 {
+ words[len(words)-1] = lastWord[0:bar]
+ formatters = strings.Split(lastWord[bar+1:], "|", -1)
+ }
+
+ // We could remember the function address here and avoid the lookup later,
+ // but it's more dynamic to let the user change the map contents underfoot.
+ // We do require the name to be present, though.
+
+ // Is it in user-supplied map?
+ for _, f := range formatters {
+ if t.formatter(f) == nil {
+ t.parseError("unknown formatter: %q", f)
+ }
+ }
+ return &variableElement{t.linenum, words, formatters}
+}
+
+// Grab the next item. If it's simple, just append it to the template.
+// Otherwise return its details.
+func (t *Template) parseSimple(item []byte) (done bool, tok int, w []string) {
+ tok, w = t.analyze(item)
+ done = true // assume for simplicity
+ switch tok {
+ case tokComment:
+ return
+ case tokText:
+ t.elems.Push(&textElement{item})
+ return
+ case tokLiteral:
+ switch w[0] {
+ case ".meta-left":
+ t.elems.Push(&literalElement{t.ldelim})
+ case ".meta-right":
+ t.elems.Push(&literalElement{t.rdelim})
+ case ".space":
+ t.elems.Push(&literalElement{space})
+ case ".tab":
+ t.elems.Push(&literalElement{tab})
+ default:
+ t.parseError("internal error: unknown literal: %s", w[0])
+ }
+ return
+ case tokVariable:
+ t.elems.Push(t.newVariable(w))
+ return
+ }
+ return false, tok, w
+}
+
+// parseRepeated and parseSection are mutually recursive
+
+func (t *Template) parseRepeated(words []string) *repeatedElement {
+ r := new(repeatedElement)
+ t.elems.Push(r)
+ r.linenum = t.linenum
+ r.field = words[2]
+ // Scan section, collecting true and false (.or) blocks.
+ r.start = t.elems.Len()
+ r.or = -1
+ r.altstart = -1
+ r.altend = -1
+Loop:
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ t.parseError("missing .end for .repeated section")
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokEnd:
+ break Loop
+ case tokOr:
+ if r.or >= 0 {
+ t.parseError("extra .or in .repeated section")
+ break Loop
+ }
+ r.altend = t.elems.Len()
+ r.or = t.elems.Len()
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ case tokAlternates:
+ if r.altstart >= 0 {
+ t.parseError("extra .alternates in .repeated section")
+ break Loop
+ }
+ if r.or >= 0 {
+ t.parseError(".alternates inside .or block in .repeated section")
+ break Loop
+ }
+ r.altstart = t.elems.Len()
+ default:
+ t.parseError("internal error: unknown repeated section item: %s", item)
+ break Loop
+ }
+ }
+ if r.altend < 0 {
+ r.altend = t.elems.Len()
+ }
+ r.end = t.elems.Len()
+ return r
+}
+
+func (t *Template) parseSection(words []string) *sectionElement {
+ s := new(sectionElement)
+ t.elems.Push(s)
+ s.linenum = t.linenum
+ s.field = words[1]
+ // Scan section, collecting true and false (.or) blocks.
+ s.start = t.elems.Len()
+ s.or = -1
+Loop:
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ t.parseError("missing .end for .section")
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokEnd:
+ break Loop
+ case tokOr:
+ if s.or >= 0 {
+ t.parseError("extra .or in .section")
+ break Loop
+ }
+ s.or = t.elems.Len()
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ case tokAlternates:
+ t.parseError(".alternates not in .repeated")
+ default:
+ t.parseError("internal error: unknown section item: %s", item)
+ }
+ }
+ s.end = t.elems.Len()
+ return s
+}
+
+func (t *Template) parse() {
+ for {
+ item := t.nextItem()
+ if len(item) == 0 {
+ break
+ }
+ done, tok, w := t.parseSimple(item)
+ if done {
+ continue
+ }
+ switch tok {
+ case tokOr, tokEnd, tokAlternates:
+ t.parseError("unexpected %s", w[0])
+ case tokSection:
+ t.parseSection(w)
+ case tokRepeated:
+ t.parseRepeated(w)
+ default:
+ t.parseError("internal error: bad directive in parse: %s", item)
+ }
+ }
+}
+
+// -- Execution
+
+// Evaluate interfaces and pointers looking for a value that can look up the name, via a
+// struct field, method, or map key, and return the result of the lookup.
+func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value {
+ for v.IsValid() {
+ typ := v.Type()
+ if n := v.Type().NumMethod(); n > 0 {
+ for i := 0; i < n; i++ {
+ m := typ.Method(i)
+ mtyp := m.Type
+ if m.Name == name && mtyp.NumIn() == 1 && mtyp.NumOut() == 1 {
+ if !isExported(name) {
+ t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type())
+ }
+ return v.Method(i).Call(nil)[0]
+ }
+ }
+ }
+ switch av := v; av.Kind() {
+ case reflect.Ptr:
+ v = av.Elem()
+ case reflect.Interface:
+ v = av.Elem()
+ case reflect.Struct:
+ if !isExported(name) {
+ t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type())
+ }
+ return av.FieldByName(name)
+ case reflect.Map:
+ if v := av.MapIndex(reflect.NewValue(name)); v.IsValid() {
+ return v
+ }
+ return reflect.Zero(typ.Elem())
+ default:
+ return reflect.Value{}
+ }
+ }
+ return v
+}
+
+// indirectPtr returns the item numLevels levels of indirection below the value.
+// It is forgiving: if the value is not a pointer, it returns it rather than giving
+// an error. If the pointer is nil, it is returned as is.
+func indirectPtr(v reflect.Value, numLevels int) reflect.Value {
+ for i := numLevels; v.IsValid() && i > 0; i++ {
+ if p := v; p.Kind() == reflect.Ptr {
+ if p.IsNil() {
+ return v
+ }
+ v = p.Elem()
+ } else {
+ break
+ }
+ }
+ return v
+}
+
+// Walk v through pointers and interfaces, extracting the elements within.
+func indirect(v reflect.Value) reflect.Value {
+loop:
+ for v.IsValid() {
+ switch av := v; av.Kind() {
+ case reflect.Ptr:
+ v = av.Elem()
+ case reflect.Interface:
+ v = av.Elem()
+ default:
+ break loop
+ }
+ }
+ return v
+}
+
+// If the data for this template is a struct, find the named variable.
+// Names of the form a.b.c are walked down the data tree.
+// The special name "@" (the "cursor") denotes the current data.
+// The value coming in (st.data) might need indirecting to reach
+// a struct while the return value is not indirected - that is,
+// it represents the actual named field. Leading stars indicate
+// levels of indirection to be applied to the value.
+func (t *Template) findVar(st *state, s string) reflect.Value {
+ data := st.data
+ flattenedName := strings.TrimLeft(s, "*")
+ numStars := len(s) - len(flattenedName)
+ s = flattenedName
+ if s == "@" {
+ return indirectPtr(data, numStars)
+ }
+ for _, elem := range strings.Split(s, ".", -1) {
+ // Look up field; data must be a struct or map.
+ data = t.lookup(st, data, elem)
+ if !data.IsValid() {
+ return reflect.Value{}
+ }
+ }
+ return indirectPtr(data, numStars)
+}
+
+// Is there no data to look at?
+func empty(v reflect.Value) bool {
+ v = indirect(v)
+ if !v.IsValid() {
+ return true
+ }
+ switch v.Kind() {
+ case reflect.Bool:
+ return v.Bool() == false
+ case reflect.String:
+ return v.String() == ""
+ case reflect.Struct:
+ return false
+ case reflect.Map:
+ return false
+ case reflect.Array:
+ return v.Len() == 0
+ case reflect.Slice:
+ return v.Len() == 0
+ }
+ return false
+}
+
+// Look up a variable or method, up through the parent if necessary.
+func (t *Template) varValue(name string, st *state) reflect.Value {
+ field := t.findVar(st, name)
+ if !field.IsValid() {
+ if st.parent == nil {
+ t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type())
+ }
+ return t.varValue(name, st.parent)
+ }
+ return field
+}
+
+func (t *Template) format(wr io.Writer, fmt string, val []interface{}, v *variableElement, st *state) {
+ fn := t.formatter(fmt)
+ if fn == nil {
+ t.execError(st, v.linenum, "missing formatter %s for variable %s", fmt, v.word[0])
+ }
+ fn(wr, fmt, val...)
+}
+
+// Evaluate a variable, looking up through the parent if necessary.
+// If it has a formatter attached ({var|formatter}) run that too.
+func (t *Template) writeVariable(v *variableElement, st *state) {
+ // Turn the words of the invocation into values.
+ val := make([]interface{}, len(v.word))
+ for i, word := range v.word {
+ val[i] = t.varValue(word, st).Interface()
+ }
+
+ for i, fmt := range v.fmts[:len(v.fmts)-1] {
+ b := &st.buf[i&1]
+ b.Reset()
+ t.format(b, fmt, val, v, st)
+ val = val[0:1]
+ val[0] = b.Bytes()
+ }
+ t.format(st.wr, v.fmts[len(v.fmts)-1], val, v, st)
+}
+
+// Execute element i. Return next index to execute.
+func (t *Template) executeElement(i int, st *state) int {
+ switch elem := t.elems.At(i).(type) {
+ case *textElement:
+ st.wr.Write(elem.text)
+ return i + 1
+ case *literalElement:
+ st.wr.Write(elem.text)
+ return i + 1
+ case *variableElement:
+ t.writeVariable(elem, st)
+ return i + 1
+ case *sectionElement:
+ t.executeSection(elem, st)
+ return elem.end
+ case *repeatedElement:
+ t.executeRepeated(elem, st)
+ return elem.end
+ }
+ e := t.elems.At(i)
+ t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e)
+ return 0
+}
+
+// Execute the template.
+func (t *Template) execute(start, end int, st *state) {
+ for i := start; i < end; {
+ i = t.executeElement(i, st)
+ }
+}
+
+// Execute a .section
+func (t *Template) executeSection(s *sectionElement, st *state) {
+ // Find driver data for this section. It must be in the current struct.
+ field := t.varValue(s.field, st)
+ if !field.IsValid() {
+ t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type())
+ }
+ st = st.clone(field)
+ start, end := s.start, s.or
+ if !empty(field) {
+ // Execute the normal block.
+ if end < 0 {
+ end = s.end
+ }
+ } else {
+ // Execute the .or block. If it's missing, do nothing.
+ start, end = s.or, s.end
+ if start < 0 {
+ return
+ }
+ }
+ for i := start; i < end; {
+ i = t.executeElement(i, st)
+ }
+}
+
+// Return the result of calling the Iter method on v, or nil.
+func iter(v reflect.Value) reflect.Value {
+ for j := 0; j < v.Type().NumMethod(); j++ {
+ mth := v.Type().Method(j)
+ fv := v.Method(j)
+ ft := fv.Type()
+ // TODO(rsc): NumIn() should return 0 here, because ft is from a curried FuncValue.
+ if mth.Name != "Iter" || ft.NumIn() != 1 || ft.NumOut() != 1 {
+ continue
+ }
+ ct := ft.Out(0)
+ if ct.Kind() != reflect.Chan ||
+ ct.ChanDir()&reflect.RecvDir == 0 {
+ continue
+ }
+ return fv.Call(nil)[0]
+ }
+ return reflect.Value{}
+}
+
+// Execute a .repeated section
+func (t *Template) executeRepeated(r *repeatedElement, st *state) {
+ // Find driver data for this section. It must be in the current struct.
+ field := t.varValue(r.field, st)
+ if !field.IsValid() {
+ t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type())
+ }
+ field = indirect(field)
+
+ start, end := r.start, r.or
+ if end < 0 {
+ end = r.end
+ }
+ if r.altstart >= 0 {
+ end = r.altstart
+ }
+ first := true
+
+ // Code common to all the loops.
+ loopBody := func(newst *state) {
+ // .alternates between elements
+ if !first && r.altstart >= 0 {
+ for i := r.altstart; i < r.altend; {
+ i = t.executeElement(i, newst)
+ }
+ }
+ first = false
+ for i := start; i < end; {
+ i = t.executeElement(i, newst)
+ }
+ }
+
+ if array := field; array.Kind() == reflect.Array || array.Kind() == reflect.Slice {
+ for j := 0; j < array.Len(); j++ {
+ loopBody(st.clone(array.Index(j)))
+ }
+ } else if m := field; m.Kind() == reflect.Map {
+ for _, key := range m.MapKeys() {
+ loopBody(st.clone(m.MapIndex(key)))
+ }
+ } else if ch := iter(field); ch.IsValid() {
+ for {
+ e, ok := ch.Recv()
+ if !ok {
+ break
+ }
+ loopBody(st.clone(e))
+ }
+ } else {
+ t.execError(st, r.linenum, ".repeated: cannot repeat %s (type %s)",
+ r.field, field.Type())
+ }
+
+ if first {
+ // Empty. Execute the .or block, once. If it's missing, do nothing.
+ start, end := r.or, r.end
+ if start >= 0 {
+ newst := st.clone(field)
+ for i := start; i < end; {
+ i = t.executeElement(i, newst)
+ }
+ }
+ return
+ }
+}
+
+// A valid delimiter must contain no white space and be non-empty.
+func validDelim(d []byte) bool {
+ if len(d) == 0 {
+ return false
+ }
+ for _, c := range d {
+ if white(c) {
+ return false
+ }
+ }
+ return true
+}
+
+// checkError is a deferred function to turn a panic with type *Error into a plain error return.
+// Other panics are unexpected and so are re-enabled.
+func checkError(error *os.Error) {
+ if v := recover(); v != nil {
+ if e, ok := v.(*Error); ok {
+ *error = e
+ } else {
+ // runtime errors should crash
+ panic(v)
+ }
+ }
+}
+
+// -- Public interface
+
+// Parse initializes a Template by parsing its definition. The string
+// s contains the template text. If any errors occur, Parse returns
+// the error.
+func (t *Template) Parse(s string) (err os.Error) {
+ if t.elems == nil {
+ return &Error{1, "template not allocated with New"}
+ }
+ if !validDelim(t.ldelim) || !validDelim(t.rdelim) {
+ return &Error{1, fmt.Sprintf("bad delimiter strings %q %q", t.ldelim, t.rdelim)}
+ }
+ defer checkError(&err)
+ t.buf = []byte(s)
+ t.p = 0
+ t.linenum = 1
+ t.parse()
+ return nil
+}
+
+// ParseFile is like Parse but reads the template definition from the
+// named file.
+func (t *Template) ParseFile(filename string) (err os.Error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return err
+ }
+ return t.Parse(string(b))
+}
+
+// Execute applies a parsed template to the specified data object,
+// generating output to wr.
+func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) {
+ // Extract the driver data.
+ val := reflect.NewValue(data)
+ defer checkError(&err)
+ t.p = 0
+ t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr})
+ return nil
+}
+
+// SetDelims sets the left and right delimiters for operations in the
+// template. They are validated during parsing. They could be
+// validated here but it's better to keep the routine simple. The
+// delimiters are very rarely invalid and Parse has the necessary
+// error-handling interface already.
+func (t *Template) SetDelims(left, right string) {
+ t.ldelim = []byte(left)
+ t.rdelim = []byte(right)
+}
+
+// Parse creates a Template with default parameters (such as {} for
+// metacharacters). The string s contains the template text while
+// the formatter map fmap, which may be nil, defines auxiliary functions
+// for formatting variables. The template is returned. If any errors
+// occur, err will be non-nil.
+func Parse(s string, fmap FormatterMap) (t *Template, err os.Error) {
+ t = New(fmap)
+ err = t.Parse(s)
+ if err != nil {
+ t = nil
+ }
+ return
+}
+
+// ParseFile is a wrapper function that creates a Template with default
+// parameters (such as {} for metacharacters). The filename identifies
+// a file containing the template text, while the formatter map fmap, which
+// may be nil, defines auxiliary functions for formatting variables.
+// The template is returned. If any errors occur, err will be non-nil.
+func ParseFile(filename string, fmap FormatterMap) (t *Template, err os.Error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ return Parse(string(b), fmap)
+}
+
+// MustParse is like Parse but panics if the template cannot be parsed.
+func MustParse(s string, fmap FormatterMap) *Template {
+ t, err := Parse(s, fmap)
+ if err != nil {
+ panic("template.MustParse error: " + err.String())
+ }
+ return t
+}
+
+// MustParseFile is like ParseFile but panics if the file cannot be read
+// or the template cannot be parsed.
+func MustParseFile(filename string, fmap FormatterMap) *Template {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ panic("template.MustParseFile error: " + err.String())
+ }
+ return MustParse(string(b), fmap)
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "sync"
+ "unicode"
+ "utf8"
+)
+
+// userTypeInfo stores the information associated with a type the user has handed
+// to the package. It's computed once and stored in a map keyed by reflection
+// type.
+type userTypeInfo struct {
+ user reflect.Type // the type the user handed us
+ base reflect.Type // the base type after all indirections
+ indir int // number of indirections to reach the base type
+ isGobEncoder bool // does the type implement GobEncoder?
+ isGobDecoder bool // does the type implement GobDecoder?
+ encIndir int8 // number of indirections to reach the receiver type; may be negative
+ decIndir int8 // number of indirections to reach the receiver type; may be negative
+}
+
+var (
+ // Protected by an RWMutex because we read it a lot and write
+ // it only when we see a new type, typically when compiling.
+ userTypeLock sync.RWMutex
+ userTypeCache = make(map[reflect.Type]*userTypeInfo)
+)
+
+// validType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, err will be non-nil. To be used when the error handler
+// is not set up.
+func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
+ userTypeLock.RLock()
+ ut = userTypeCache[rt]
+ userTypeLock.RUnlock()
+ if ut != nil {
+ return
+ }
+ // Now set the value under the write lock.
+ userTypeLock.Lock()
+ defer userTypeLock.Unlock()
+ if ut = userTypeCache[rt]; ut != nil {
+ // Lost the race; not a problem.
+ return
+ }
+ ut = new(userTypeInfo)
+ ut.base = rt
+ ut.user = rt
+ // A type that is just a cycle of pointers (such as type T *T) cannot
+ // be represented in gobs, which need some concrete data. We use a
+ // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
+ // pp 539-540. As we step through indirections, run another type at
+ // half speed. If they meet up, there's a cycle.
+ slowpoke := ut.base // walks half as fast as ut.base
+ for {
+ pt, ok := ut.base.(*reflect.PtrType)
+ if !ok {
+ break
+ }
+ ut.base = pt.Elem()
+ if ut.base == slowpoke { // ut.base lapped slowpoke
+ // recursive pointer type.
+ return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String())
+ }
+ if ut.indir%2 == 0 {
+ slowpoke = slowpoke.(*reflect.PtrType).Elem()
+ }
+ ut.indir++
+ }
+ ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck)
+ ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck)
+ userTypeCache[rt] = ut
+ return
+}
+
+const (
+ gobEncodeMethodName = "GobEncode"
+ gobDecodeMethodName = "GobDecode"
+)
+
+// implements returns whether the type implements the interface, as encoded
+// in the check function.
+func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool {
+ if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance
+ return false
+ }
+ return check(typ)
+}
+
+// gobEncoderCheck makes the type assertion a boolean function.
+func gobEncoderCheck(typ reflect.Type) bool {
+ _, ok := reflect.MakeZero(typ).Interface().(GobEncoder)
+ return ok
+}
+
+// gobDecoderCheck makes the type assertion a boolean function.
+func gobDecoderCheck(typ reflect.Type) bool {
+ _, ok := reflect.MakeZero(typ).Interface().(GobDecoder)
+ return ok
+}
+
+// implementsInterface reports whether the type implements the
+// interface. (The actual check is done through the provided function.)
+// It also returns the number of indirections required to get to the
+// implementation.
+func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) {
+ if typ == nil {
+ return
+ }
+ rt := typ
+ // The type might be a pointer and we need to keep
+ // dereferencing to the base type until we find an implementation.
+ for {
+ if implements(rt, check) {
+ return true, indir
+ }
+ if p, ok := rt.(*reflect.PtrType); ok {
+ indir++
+ if indir > 100 { // insane number of indirections
+ return false, 0
+ }
+ rt = p.Elem()
+ continue
+ }
+ break
+ }
+ // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy.
+ if _, ok := typ.(*reflect.PtrType); !ok {
+ // Not a pointer, but does the pointer work?
+ if implements(reflect.PtrTo(typ), check) {
+ return true, -1
+ }
+ }
+ return false, 0
+}
+
+// userType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, it calls error.
+func userType(rt reflect.Type) *userTypeInfo {
+ ut, err := validUserType(rt)
+ if err != nil {
+ error(err)
+ }
+ return ut
+}
+// A typeId represents a gob Type as an integer that can be passed on the wire.
+// Internally, typeIds are used as keys to a map to recover the underlying type info.
+type typeId int32
+
+var nextId typeId // incremented for each new type we build
+var typeLock sync.Mutex // set while building a type
+const firstUserId = 64 // lowest id number granted to user
+
+type gobType interface {
+ id() typeId
+ setId(id typeId)
+ name() string
+ string() string // not public; only for debugging
+ safeString(seen map[typeId]bool) string
+}
+
+var types = make(map[reflect.Type]gobType)
+var idToType = make(map[typeId]gobType)
+var builtinIdToType map[typeId]gobType // set in init() after builtins are established
+
+func setTypeId(typ gobType) {
+ nextId++
+ typ.setId(nextId)
+ idToType[nextId] = typ
+}
+
+func (t typeId) gobType() gobType {
+ if t == 0 {
+ return nil
+ }
+ return idToType[t]
+}
+
+// string returns the string representation of the type associated with the typeId.
+func (t typeId) string() string {
+ if t.gobType() == nil {
+ return "<nil>"
+ }
+ return t.gobType().string()
+}
+
+// Name returns the name of the type associated with the typeId.
+func (t typeId) name() string {
+ if t.gobType() == nil {
+ return "<nil>"
+ }
+ return t.gobType().name()
+}
+
+// Common elements of all types.
+type CommonType struct {
+ Name string
+ Id typeId
+}
+
+func (t *CommonType) id() typeId { return t.Id }
+
+func (t *CommonType) setId(id typeId) { t.Id = id }
+
+func (t *CommonType) string() string { return t.Name }
+
+func (t *CommonType) safeString(seen map[typeId]bool) string {
+ return t.Name
+}
+
+func (t *CommonType) name() string { return t.Name }
+
+// Create and check predefined types
+// The string for tBytes is "bytes" not "[]byte" to signify its specialness.
+
+var (
+ // Primordial types, needed during initialization.
+ // Always passed as pointers so the interface{} type
+ // goes through without losing its interfaceness.
+ tBool = bootstrapType("bool", (*bool)(nil), 1)
+ tInt = bootstrapType("int", (*int)(nil), 2)
+ tUint = bootstrapType("uint", (*uint)(nil), 3)
+ tFloat = bootstrapType("float", (*float64)(nil), 4)
+ tBytes = bootstrapType("bytes", (*[]byte)(nil), 5)
+ tString = bootstrapType("string", (*string)(nil), 6)
+ tComplex = bootstrapType("complex", (*complex128)(nil), 7)
+ tInterface = bootstrapType("interface", (*interface{})(nil), 8)
+ // Reserve some Ids for compatible expansion
+ tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil), 9)
+ tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil), 10)
+ tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil), 11)
+ tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil), 12)
+ tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil), 13)
+ tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil), 14)
+ tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil), 15)
+)
+
+// Predefined because it's needed by the Decoder
+var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id
+var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType)
+
+func init() {
+ // Some magic numbers to make sure there are no surprises.
+ checkId(16, tWireType)
+ checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id)
+ checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id)
+ checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id)
+ checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id)
+ checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id)
+ checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id)
+
+ builtinIdToType = make(map[typeId]gobType)
+ for k, v := range idToType {
+ builtinIdToType[k] = v
+ }
+
+ // Move the id space upwards to allow for growth in the predefined world
+ // without breaking existing files.
+ if nextId > firstUserId {
+ panic(fmt.Sprintln("nextId too large:", nextId))
+ }
+ nextId = firstUserId
+ registerBasics()
+ wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil)))
+}
+
+// Array type
+type arrayType struct {
+ CommonType
+ Elem typeId
+ Len int
+}
+
+func newArrayType(name string) *arrayType {
+ a := &arrayType{CommonType{Name: name}, 0, 0}
+ return a
+}
+
+func (a *arrayType) init(elem gobType, len int) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(a)
+ a.Elem = elem.id()
+ a.Len = len
+}
+
+func (a *arrayType) safeString(seen map[typeId]bool) string {
+ if seen[a.Id] {
+ return a.Name
+ }
+ seen[a.Id] = true
+ return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen))
+}
+
+func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
+
+// GobEncoder type (something that implements the GobEncoder interface)
+type gobEncoderType struct {
+ CommonType
+}
+
+func newGobEncoderType(name string) *gobEncoderType {
+ g := &gobEncoderType{CommonType{Name: name}}
+ setTypeId(g)
+ return g
+}
+
+func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
+ return g.Name
+}
+
+func (g *gobEncoderType) string() string { return g.Name }
+
+// Map type
+type mapType struct {
+ CommonType
+ Key typeId
+ Elem typeId
+}
+
+func newMapType(name string) *mapType {
+ m := &mapType{CommonType{Name: name}, 0, 0}
+ return m
+}
+
+func (m *mapType) init(key, elem gobType) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(m)
+ m.Key = key.id()
+ m.Elem = elem.id()
+}
+
+func (m *mapType) safeString(seen map[typeId]bool) string {
+ if seen[m.Id] {
+ return m.Name
+ }
+ seen[m.Id] = true
+ key := m.Key.gobType().safeString(seen)
+ elem := m.Elem.gobType().safeString(seen)
+ return fmt.Sprintf("map[%s]%s", key, elem)
+}
+
+func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
+
+// Slice type
+type sliceType struct {
+ CommonType
+ Elem typeId
+}
+
+func newSliceType(name string) *sliceType {
+ s := &sliceType{CommonType{Name: name}, 0}
+ return s
+}
+
+func (s *sliceType) init(elem gobType) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(s)
+ s.Elem = elem.id()
+}
+
+func (s *sliceType) safeString(seen map[typeId]bool) string {
+ if seen[s.Id] {
+ return s.Name
+ }
+ seen[s.Id] = true
+ return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen))
+}
+
+func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) }
+
+// Struct type
+type fieldType struct {
+ Name string
+ Id typeId
+}
+
+type structType struct {
+ CommonType
+ Field []*fieldType
+}
+
+func (s *structType) safeString(seen map[typeId]bool) string {
+ if s == nil {
+ return "<nil>"
+ }
+ if _, ok := seen[s.Id]; ok {
+ return s.Name
+ }
+ seen[s.Id] = true
+ str := s.Name + " = struct { "
+ for _, f := range s.Field {
+ str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen))
+ }
+ str += "}"
+ return str
+}
+
+func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) }
+
+func newStructType(name string) *structType {
+ s := &structType{CommonType{Name: name}, nil}
+ // For historical reasons we set the id here rather than init.
+ // See the comment in newTypeObject for details.
+ setTypeId(s)
+ return s
+}
+
+// newTypeObject allocates a gobType for the reflection type rt.
+// Unless ut represents a GobEncoder, rt should be the base type
+// of ut.
+// This is only called from the encoding side. The decoding side
+// works through typeIds and userTypeInfos alone.
+func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ // Does this type implement GobEncoder?
+ if ut.isGobEncoder {
+ return newGobEncoderType(name), nil
+ }
+ var err os.Error
+ var type0, type1 gobType
+ defer func() {
+ if err != nil {
+ types[rt] = nil, false
+ }
+ }()
+ // Install the top-level type before the subtypes (e.g. struct before
+ // fields) so recursive types can be constructed safely.
+ switch t := rt.(type) {
+ // All basic types are easy: they are predefined.
+ case *reflect.BoolType:
+ return tBool.gobType(), nil
+
+ case *reflect.IntType:
+ return tInt.gobType(), nil
+
+ case *reflect.UintType:
+ return tUint.gobType(), nil
+
+ case *reflect.FloatType:
+ return tFloat.gobType(), nil
+
+ case *reflect.ComplexType:
+ return tComplex.gobType(), nil
+
+ case *reflect.StringType:
+ return tString.gobType(), nil
+
+ case *reflect.InterfaceType:
+ return tInterface.gobType(), nil
+
+ case *reflect.ArrayType:
+ at := newArrayType(name)
+ types[rt] = at
+ type0, err = getBaseType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ // Historical aside:
+ // For arrays, maps, and slices, we set the type id after the elements
+ // are constructed. This is to retain the order of type id allocation after
+ // a fix made to handle recursive types, which changed the order in
+ // which types are built. Delaying the setting in this way preserves
+ // type ids while allowing recursive types to be described. Structs,
+ // done below, were already handling recursion correctly so they
+ // assign the top-level id before those of the field.
+ at.init(type0, t.Len())
+ return at, nil
+
+ case *reflect.MapType:
+ mt := newMapType(name)
+ types[rt] = mt
+ type0, err = getBaseType("", t.Key())
+ if err != nil {
+ return nil, err
+ }
+ type1, err = getBaseType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ mt.init(type0, type1)
+ return mt, nil
+
+ case *reflect.SliceType:
+ // []byte == []uint8 is a special case
+ if t.Elem().Kind() == reflect.Uint8 {
+ return tBytes.gobType(), nil
+ }
+ st := newSliceType(name)
+ types[rt] = st
+ type0, err = getBaseType(t.Elem().Name(), t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ st.init(type0)
+ return st, nil
+
+ case *reflect.StructType:
+ st := newStructType(name)
+ types[rt] = st
+ idToType[st.id()] = st
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ if !isExported(f.Name) {
+ continue
+ }
+ typ := userType(f.Type).base
+ tname := typ.Name()
+ if tname == "" {
+ t := userType(f.Type).base
+ tname = t.String()
+ }
+ gt, err := getBaseType(tname, f.Type)
+ if err != nil {
+ return nil, err
+ }
+ st.Field = append(st.Field, &fieldType{f.Name, gt.id()})
+ }
+ return st, nil
+
+ default:
+ return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String())
+ }
+ return nil, nil
+}
+
+// isExported reports whether this is an exported - upper case - name.
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// getBaseType returns the Gob type describing the given reflect.Type's base type.
+// typeLock must be held.
+func getBaseType(name string, rt reflect.Type) (gobType, os.Error) {
+ ut := userType(rt)
+ return getType(name, ut, ut.base)
+}
+
+// getType returns the Gob type describing the given reflect.Type.
+// Should be called only when handling GobEncoders/Decoders,
+// which may be pointers. All other types are handled through the
+// base type, never a pointer.
+// typeLock must be held.
+func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ typ, present := types[rt]
+ if present {
+ return typ, nil
+ }
+ typ, err := newTypeObject(name, ut, rt)
+ if err == nil {
+ types[rt] = typ
+ }
+ return typ, err
+}
+
+func checkId(want, got typeId) {
+ if want != got {
+ fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want))
+ panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string())
+ }
+}
+
+// used for building the basic types; called only from init(). the incoming
+// interface always refers to a pointer.
+func bootstrapType(name string, e interface{}, expect typeId) typeId {
+ rt := reflect.Typeof(e).(*reflect.PtrType).Elem()
+ _, present := types[rt]
+ if present {
+ panic("bootstrap type already present: " + name + ", " + rt.String())
+ }
+ typ := &CommonType{Name: name}
+ types[rt] = typ
+ setTypeId(typ)
+ checkId(expect, nextId)
+ userType(rt) // might as well cache it now
+ return nextId
+}
+
+// Representation of the information we send and receive about this type.
+// Each value we send is preceded by its type definition: an encoded int.
+// However, the very first time we send the value, we first send the pair
+// (-id, wireType).
+// For bootstrapping purposes, we assume that the recipient knows how
+// to decode a wireType; it is exactly the wireType struct here, interpreted
+// using the gob rules for sending a structure, except that we assume the
+// ids for wireType and structType etc. are known. The relevant pieces
+// are built in encode.go's init() function.
+// To maintain binary compatibility, if you extend this type, always put
+// the new fields last.
+type wireType struct {
+ ArrayT *arrayType
+ SliceT *sliceType
+ StructT *structType
+ MapT *mapType
+ GobEncoderT *gobEncoderType
+}
+
+func (w *wireType) string() string {
+ const unknown = "unknown type"
+ if w == nil {
+ return unknown
+ }
+ switch {
+ case w.ArrayT != nil:
+ return w.ArrayT.Name
+ case w.SliceT != nil:
+ return w.SliceT.Name
+ case w.StructT != nil:
+ return w.StructT.Name
+ case w.MapT != nil:
+ return w.MapT.Name
+ case w.GobEncoderT != nil:
+ return w.GobEncoderT.Name
+ }
+ return unknown
+}
+
+type typeInfo struct {
+ id typeId
+ encoder *encEngine
+ wire *wireType
+}
+
+var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock
+
+// typeLock must be held.
+func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) {
+ rt := ut.base
+ if ut.isGobEncoder {
+ // We want the user type, not the base type.
+ rt = ut.user
+ }
+ info, ok := typeInfoMap[rt]
+ if ok {
+ return info, nil
+ }
+ info = new(typeInfo)
+ gt, err := getBaseType(rt.Name(), rt)
+ if err != nil {
+ return nil, err
+ }
+ info.id = gt.id()
+
+ if ut.isGobEncoder {
+ userType, err := getType(rt.Name(), ut, rt)
+ if err != nil {
+ return nil, err
+ }
+ info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)}
+ typeInfoMap[ut.user] = info
+ return info, nil
+ }
+
+ t := info.id.gobType()
+ switch typ := rt.(type) {
+ case *reflect.ArrayType:
+ info.wire = &wireType{ArrayT: t.(*arrayType)}
+ case *reflect.MapType:
+ info.wire = &wireType{MapT: t.(*mapType)}
+ case *reflect.SliceType:
+ // []byte == []uint8 is a special case handled separately
+ if typ.Elem().Kind() != reflect.Uint8 {
+ info.wire = &wireType{SliceT: t.(*sliceType)}
+ }
+ case *reflect.StructType:
+ info.wire = &wireType{StructT: t.(*structType)}
+ }
+ typeInfoMap[rt] = info
+ return info, nil
+}
+
+// Called only when a panic is acceptable and unexpected.
+func mustGetTypeInfo(rt reflect.Type) *typeInfo {
+ t, err := getTypeInfo(userType(rt))
+ if err != nil {
+ panic("getTypeInfo: " + err.String())
+ }
+ return t
+}
+
+// GobEncoder is the interface describing data that provides its own
+// representation for encoding values for transmission to a GobDecoder.
+// A type that implements GobEncoder and GobDecoder has complete
+// control over the representation of its data and may therefore
+// contain things such as private fields, channels, and functions,
+// which are not usually transmissable in gob streams.
+//
+// Note: Since gobs can be stored permanently, It is good design
+// to guarantee the encoding used by a GobEncoder is stable as the
+// software evolves. For instance, it might make sense for GobEncode
+// to include a version number in the encoding.
+type GobEncoder interface {
+ // GobEncode returns a byte slice representing the encoding of the
+ // receiver for transmission to a GobDecoder, usually of the same
+ // concrete type.
+ GobEncode() ([]byte, os.Error)
+}
+
+// GobDecoder is the interface describing data that provides its own
+// routine for decoding transmitted values sent by a GobEncoder.
+type GobDecoder interface {
+ // GobDecode overwrites the receiver, which must be a pointer,
+ // with the value represented by the byte slice, which was written
+ // by GobEncode, usually for the same concrete type.
+ GobDecode([]byte) os.Error
+}
+
+var (
+ nameToConcreteType = make(map[string]reflect.Type)
+ concreteTypeToName = make(map[reflect.Type]string)
+)
+
+// RegisterName is like Register but uses the provided name rather than the
+// type's default.
+func RegisterName(name string, value interface{}) {
+ if name == "" {
+ // reserved for nil
+ panic("attempt to register empty name")
+ }
+ base := userType(reflect.Typeof(value)).base
+ // Check for incompatible duplicates.
+ if t, ok := nameToConcreteType[name]; ok && t != base {
+ panic("gob: registering duplicate types for " + name)
+ }
+ if n, ok := concreteTypeToName[base]; ok && n != name {
+ panic("gob: registering duplicate names for " + base.String())
+ }
+ // Store the name and type provided by the user....
+ nameToConcreteType[name] = reflect.Typeof(value)
+ // but the flattened type in the type table, since that's what decode needs.
+ concreteTypeToName[base] = name
+}
+
+// Register records a type, identified by a value for that type, under its
+// internal type name. That name will identify the concrete type of a value
+// sent or received as an interface variable. Only types that will be
+// transferred as implementations of interface values need to be registered.
+// Expecting to be used only during initialization, it panics if the mapping
+// between types and names is not a bijection.
+func Register(value interface{}) {
+ // Default to printed representation for unnamed types
+ rt := reflect.Typeof(value)
+ name := rt.String()
+
+ // But for named types (or pointers to them), qualify with import path.
+ // Dereference one pointer looking for a named type.
+ star := ""
+ if rt.Name() == "" {
+ if pt, ok := rt.(*reflect.PtrType); ok {
+ star = "*"
+ rt = pt
+ }
+ }
+ if rt.Name() != "" {
+ if rt.PkgPath() == "" {
+ name = star + rt.Name()
+ } else {
+ name = star + rt.PkgPath() + "." + rt.Name()
+ }
+ }
+
+ RegisterName(name, value)
+}
+
+func registerBasics() {
+ Register(int(0))
+ Register(int8(0))
+ Register(int16(0))
+ Register(int32(0))
+ Register(int64(0))
+ Register(uint(0))
+ Register(uint8(0))
+ Register(uint16(0))
+ Register(uint32(0))
+ Register(uint64(0))
+ Register(float32(0))
+ Register(float64(0))
+ Register(complex64(0i))
+ Register(complex128(0i))
+ Register(false)
+ Register("")
+ Register([]byte(nil))
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gob
+
+import (
+ "fmt"
+ "os"
+ "reflect"
+ "sync"
+ "unicode"
+ "utf8"
+)
+
+// userTypeInfo stores the information associated with a type the user has handed
+// to the package. It's computed once and stored in a map keyed by reflection
+// type.
+type userTypeInfo struct {
+ user reflect.Type // the type the user handed us
+ base reflect.Type // the base type after all indirections
+ indir int // number of indirections to reach the base type
+ isGobEncoder bool // does the type implement GobEncoder?
+ isGobDecoder bool // does the type implement GobDecoder?
+ encIndir int8 // number of indirections to reach the receiver type; may be negative
+ decIndir int8 // number of indirections to reach the receiver type; may be negative
+}
+
+var (
+ // Protected by an RWMutex because we read it a lot and write
+ // it only when we see a new type, typically when compiling.
+ userTypeLock sync.RWMutex
+ userTypeCache = make(map[reflect.Type]*userTypeInfo)
+)
+
+// validType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, err will be non-nil. To be used when the error handler
+// is not set up.
+func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
+ userTypeLock.RLock()
+ ut = userTypeCache[rt]
+ userTypeLock.RUnlock()
+ if ut != nil {
+ return
+ }
+ // Now set the value under the write lock.
+ userTypeLock.Lock()
+ defer userTypeLock.Unlock()
+ if ut = userTypeCache[rt]; ut != nil {
+ // Lost the race; not a problem.
+ return
+ }
+ ut = new(userTypeInfo)
+ ut.base = rt
+ ut.user = rt
+ // A type that is just a cycle of pointers (such as type T *T) cannot
+ // be represented in gobs, which need some concrete data. We use a
+ // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
+ // pp 539-540. As we step through indirections, run another type at
+ // half speed. If they meet up, there's a cycle.
+ slowpoke := ut.base // walks half as fast as ut.base
+ for {
+ pt := ut.base
+ if pt.Kind() != reflect.Ptr {
+ break
+ }
+ ut.base = pt.Elem()
+ if ut.base == slowpoke { // ut.base lapped slowpoke
+ // recursive pointer type.
+ return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String())
+ }
+ if ut.indir%2 == 0 {
+ slowpoke = slowpoke.Elem()
+ }
+ ut.indir++
+ }
+ ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck)
+ ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck)
+ userTypeCache[rt] = ut
+ return
+}
+
+const (
+ gobEncodeMethodName = "GobEncode"
+ gobDecodeMethodName = "GobDecode"
+)
+
+// implements returns whether the type implements the interface, as encoded
+// in the check function.
+func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool {
+ if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance
+ return false
+ }
+ return check(typ)
+}
+
+// gobEncoderCheck makes the type assertion a boolean function.
+func gobEncoderCheck(typ reflect.Type) bool {
+ _, ok := reflect.Zero(typ).Interface().(GobEncoder)
+ return ok
+}
+
+// gobDecoderCheck makes the type assertion a boolean function.
+func gobDecoderCheck(typ reflect.Type) bool {
+ _, ok := reflect.Zero(typ).Interface().(GobDecoder)
+ return ok
+}
+
+// implementsInterface reports whether the type implements the
+// interface. (The actual check is done through the provided function.)
+// It also returns the number of indirections required to get to the
+// implementation.
+func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) {
+ if typ == nil {
+ return
+ }
+ rt := typ
+ // The type might be a pointer and we need to keep
+ // dereferencing to the base type until we find an implementation.
+ for {
+ if implements(rt, check) {
+ return true, indir
+ }
+ if p := rt; p.Kind() == reflect.Ptr {
+ indir++
+ if indir > 100 { // insane number of indirections
+ return false, 0
+ }
+ rt = p.Elem()
+ continue
+ }
+ break
+ }
+ // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy.
+ if typ.Kind() != reflect.Ptr {
+ // Not a pointer, but does the pointer work?
+ if implements(reflect.PtrTo(typ), check) {
+ return true, -1
+ }
+ }
+ return false, 0
+}
+
+// userType returns, and saves, the information associated with user-provided type rt.
+// If the user type is not valid, it calls error.
+func userType(rt reflect.Type) *userTypeInfo {
+ ut, err := validUserType(rt)
+ if err != nil {
+ error(err)
+ }
+ return ut
+}
+// A typeId represents a gob Type as an integer that can be passed on the wire.
+// Internally, typeIds are used as keys to a map to recover the underlying type info.
+type typeId int32
+
+var nextId typeId // incremented for each new type we build
+var typeLock sync.Mutex // set while building a type
+const firstUserId = 64 // lowest id number granted to user
+
+type gobType interface {
+ id() typeId
+ setId(id typeId)
+ name() string
+ string() string // not public; only for debugging
+ safeString(seen map[typeId]bool) string
+}
+
+var types = make(map[reflect.Type]gobType)
+var idToType = make(map[typeId]gobType)
+var builtinIdToType map[typeId]gobType // set in init() after builtins are established
+
+func setTypeId(typ gobType) {
+ nextId++
+ typ.setId(nextId)
+ idToType[nextId] = typ
+}
+
+func (t typeId) gobType() gobType {
+ if t == 0 {
+ return nil
+ }
+ return idToType[t]
+}
+
+// string returns the string representation of the type associated with the typeId.
+func (t typeId) string() string {
+ if t.gobType() == nil {
+ return "<nil>"
+ }
+ return t.gobType().string()
+}
+
+// Name returns the name of the type associated with the typeId.
+func (t typeId) name() string {
+ if t.gobType() == nil {
+ return "<nil>"
+ }
+ return t.gobType().name()
+}
+
+// Common elements of all types.
+type CommonType struct {
+ Name string
+ Id typeId
+}
+
+func (t *CommonType) id() typeId { return t.Id }
+
+func (t *CommonType) setId(id typeId) { t.Id = id }
+
+func (t *CommonType) string() string { return t.Name }
+
+func (t *CommonType) safeString(seen map[typeId]bool) string {
+ return t.Name
+}
+
+func (t *CommonType) name() string { return t.Name }
+
+// Create and check predefined types
+// The string for tBytes is "bytes" not "[]byte" to signify its specialness.
+
+var (
+ // Primordial types, needed during initialization.
+ // Always passed as pointers so the interface{} type
+ // goes through without losing its interfaceness.
+ tBool = bootstrapType("bool", (*bool)(nil), 1)
+ tInt = bootstrapType("int", (*int)(nil), 2)
+ tUint = bootstrapType("uint", (*uint)(nil), 3)
+ tFloat = bootstrapType("float", (*float64)(nil), 4)
+ tBytes = bootstrapType("bytes", (*[]byte)(nil), 5)
+ tString = bootstrapType("string", (*string)(nil), 6)
+ tComplex = bootstrapType("complex", (*complex128)(nil), 7)
+ tInterface = bootstrapType("interface", (*interface{})(nil), 8)
+ // Reserve some Ids for compatible expansion
+ tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil), 9)
+ tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil), 10)
+ tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil), 11)
+ tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil), 12)
+ tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil), 13)
+ tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil), 14)
+ tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil), 15)
+)
+
+// Predefined because it's needed by the Decoder
+var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id
+var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType)
+
+func init() {
+ // Some magic numbers to make sure there are no surprises.
+ checkId(16, tWireType)
+ checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id)
+ checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id)
+ checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id)
+ checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id)
+ checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id)
+ checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id)
+
+ builtinIdToType = make(map[typeId]gobType)
+ for k, v := range idToType {
+ builtinIdToType[k] = v
+ }
+
+ // Move the id space upwards to allow for growth in the predefined world
+ // without breaking existing files.
+ if nextId > firstUserId {
+ panic(fmt.Sprintln("nextId too large:", nextId))
+ }
+ nextId = firstUserId
+ registerBasics()
+ wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil)))
+}
+
+// Array type
+type arrayType struct {
+ CommonType
+ Elem typeId
+ Len int
+}
+
+func newArrayType(name string) *arrayType {
+ a := &arrayType{CommonType{Name: name}, 0, 0}
+ return a
+}
+
+func (a *arrayType) init(elem gobType, len int) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(a)
+ a.Elem = elem.id()
+ a.Len = len
+}
+
+func (a *arrayType) safeString(seen map[typeId]bool) string {
+ if seen[a.Id] {
+ return a.Name
+ }
+ seen[a.Id] = true
+ return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen))
+}
+
+func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
+
+// GobEncoder type (something that implements the GobEncoder interface)
+type gobEncoderType struct {
+ CommonType
+}
+
+func newGobEncoderType(name string) *gobEncoderType {
+ g := &gobEncoderType{CommonType{Name: name}}
+ setTypeId(g)
+ return g
+}
+
+func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
+ return g.Name
+}
+
+func (g *gobEncoderType) string() string { return g.Name }
+
+// Map type
+type mapType struct {
+ CommonType
+ Key typeId
+ Elem typeId
+}
+
+func newMapType(name string) *mapType {
+ m := &mapType{CommonType{Name: name}, 0, 0}
+ return m
+}
+
+func (m *mapType) init(key, elem gobType) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(m)
+ m.Key = key.id()
+ m.Elem = elem.id()
+}
+
+func (m *mapType) safeString(seen map[typeId]bool) string {
+ if seen[m.Id] {
+ return m.Name
+ }
+ seen[m.Id] = true
+ key := m.Key.gobType().safeString(seen)
+ elem := m.Elem.gobType().safeString(seen)
+ return fmt.Sprintf("map[%s]%s", key, elem)
+}
+
+func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
+
+// Slice type
+type sliceType struct {
+ CommonType
+ Elem typeId
+}
+
+func newSliceType(name string) *sliceType {
+ s := &sliceType{CommonType{Name: name}, 0}
+ return s
+}
+
+func (s *sliceType) init(elem gobType) {
+ // Set our type id before evaluating the element's, in case it's our own.
+ setTypeId(s)
+ s.Elem = elem.id()
+}
+
+func (s *sliceType) safeString(seen map[typeId]bool) string {
+ if seen[s.Id] {
+ return s.Name
+ }
+ seen[s.Id] = true
+ return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen))
+}
+
+func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) }
+
+// Struct type
+type fieldType struct {
+ Name string
+ Id typeId
+}
+
+type structType struct {
+ CommonType
+ Field []*fieldType
+}
+
+func (s *structType) safeString(seen map[typeId]bool) string {
+ if s == nil {
+ return "<nil>"
+ }
+ if _, ok := seen[s.Id]; ok {
+ return s.Name
+ }
+ seen[s.Id] = true
+ str := s.Name + " = struct { "
+ for _, f := range s.Field {
+ str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen))
+ }
+ str += "}"
+ return str
+}
+
+func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) }
+
+func newStructType(name string) *structType {
+ s := &structType{CommonType{Name: name}, nil}
+ // For historical reasons we set the id here rather than init.
+ // See the comment in newTypeObject for details.
+ setTypeId(s)
+ return s
+}
+
+// newTypeObject allocates a gobType for the reflection type rt.
+// Unless ut represents a GobEncoder, rt should be the base type
+// of ut.
+// This is only called from the encoding side. The decoding side
+// works through typeIds and userTypeInfos alone.
+func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ // Does this type implement GobEncoder?
+ if ut.isGobEncoder {
+ return newGobEncoderType(name), nil
+ }
+ var err os.Error
+ var type0, type1 gobType
+ defer func() {
+ if err != nil {
+ types[rt] = nil, false
+ }
+ }()
+ // Install the top-level type before the subtypes (e.g. struct before
+ // fields) so recursive types can be constructed safely.
+ switch t := rt; t.Kind() {
+ // All basic types are easy: they are predefined.
+ case reflect.Bool:
+ return tBool.gobType(), nil
+
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return tInt.gobType(), nil
+
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return tUint.gobType(), nil
+
+ case reflect.Float32, reflect.Float64:
+ return tFloat.gobType(), nil
+
+ case reflect.Complex64, reflect.Complex128:
+ return tComplex.gobType(), nil
+
+ case reflect.String:
+ return tString.gobType(), nil
+
+ case reflect.Interface:
+ return tInterface.gobType(), nil
+
+ case reflect.Array:
+ at := newArrayType(name)
+ types[rt] = at
+ type0, err = getBaseType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ // Historical aside:
+ // For arrays, maps, and slices, we set the type id after the elements
+ // are constructed. This is to retain the order of type id allocation after
+ // a fix made to handle recursive types, which changed the order in
+ // which types are built. Delaying the setting in this way preserves
+ // type ids while allowing recursive types to be described. Structs,
+ // done below, were already handling recursion correctly so they
+ // assign the top-level id before those of the field.
+ at.init(type0, t.Len())
+ return at, nil
+
+ case reflect.Map:
+ mt := newMapType(name)
+ types[rt] = mt
+ type0, err = getBaseType("", t.Key())
+ if err != nil {
+ return nil, err
+ }
+ type1, err = getBaseType("", t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ mt.init(type0, type1)
+ return mt, nil
+
+ case reflect.Slice:
+ // []byte == []uint8 is a special case
+ if t.Elem().Kind() == reflect.Uint8 {
+ return tBytes.gobType(), nil
+ }
+ st := newSliceType(name)
+ types[rt] = st
+ type0, err = getBaseType(t.Elem().Name(), t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ st.init(type0)
+ return st, nil
+
+ case reflect.Struct:
+ st := newStructType(name)
+ types[rt] = st
+ idToType[st.id()] = st
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ if !isExported(f.Name) {
+ continue
+ }
+ typ := userType(f.Type).base
+ tname := typ.Name()
+ if tname == "" {
+ t := userType(f.Type).base
+ tname = t.String()
+ }
+ gt, err := getBaseType(tname, f.Type)
+ if err != nil {
+ return nil, err
+ }
+ st.Field = append(st.Field, &fieldType{f.Name, gt.id()})
+ }
+ return st, nil
+
+ default:
+ return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String())
+ }
+ return nil, nil
+}
+
+// isExported reports whether this is an exported - upper case - name.
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// getBaseType returns the Gob type describing the given reflect.Type's base type.
+// typeLock must be held.
+func getBaseType(name string, rt reflect.Type) (gobType, os.Error) {
+ ut := userType(rt)
+ return getType(name, ut, ut.base)
+}
+
+// getType returns the Gob type describing the given reflect.Type.
+// Should be called only when handling GobEncoders/Decoders,
+// which may be pointers. All other types are handled through the
+// base type, never a pointer.
+// typeLock must be held.
+func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ typ, present := types[rt]
+ if present {
+ return typ, nil
+ }
+ typ, err := newTypeObject(name, ut, rt)
+ if err == nil {
+ types[rt] = typ
+ }
+ return typ, err
+}
+
+func checkId(want, got typeId) {
+ if want != got {
+ fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want))
+ panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string())
+ }
+}
+
+// used for building the basic types; called only from init(). the incoming
+// interface always refers to a pointer.
+func bootstrapType(name string, e interface{}, expect typeId) typeId {
+ rt := reflect.Typeof(e).Elem()
+ _, present := types[rt]
+ if present {
+ panic("bootstrap type already present: " + name + ", " + rt.String())
+ }
+ typ := &CommonType{Name: name}
+ types[rt] = typ
+ setTypeId(typ)
+ checkId(expect, nextId)
+ userType(rt) // might as well cache it now
+ return nextId
+}
+
+// Representation of the information we send and receive about this type.
+// Each value we send is preceded by its type definition: an encoded int.
+// However, the very first time we send the value, we first send the pair
+// (-id, wireType).
+// For bootstrapping purposes, we assume that the recipient knows how
+// to decode a wireType; it is exactly the wireType struct here, interpreted
+// using the gob rules for sending a structure, except that we assume the
+// ids for wireType and structType etc. are known. The relevant pieces
+// are built in encode.go's init() function.
+// To maintain binary compatibility, if you extend this type, always put
+// the new fields last.
+type wireType struct {
+ ArrayT *arrayType
+ SliceT *sliceType
+ StructT *structType
+ MapT *mapType
+ GobEncoderT *gobEncoderType
+}
+
+func (w *wireType) string() string {
+ const unknown = "unknown type"
+ if w == nil {
+ return unknown
+ }
+ switch {
+ case w.ArrayT != nil:
+ return w.ArrayT.Name
+ case w.SliceT != nil:
+ return w.SliceT.Name
+ case w.StructT != nil:
+ return w.StructT.Name
+ case w.MapT != nil:
+ return w.MapT.Name
+ case w.GobEncoderT != nil:
+ return w.GobEncoderT.Name
+ }
+ return unknown
+}
+
+type typeInfo struct {
+ id typeId
+ encoder *encEngine
+ wire *wireType
+}
+
+var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock
+
+// typeLock must be held.
+func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) {
+ rt := ut.base
+ if ut.isGobEncoder {
+ // We want the user type, not the base type.
+ rt = ut.user
+ }
+ info, ok := typeInfoMap[rt]
+ if ok {
+ return info, nil
+ }
+ info = new(typeInfo)
+ gt, err := getBaseType(rt.Name(), rt)
+ if err != nil {
+ return nil, err
+ }
+ info.id = gt.id()
+
+ if ut.isGobEncoder {
+ userType, err := getType(rt.Name(), ut, rt)
+ if err != nil {
+ return nil, err
+ }
+ info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)}
+ typeInfoMap[ut.user] = info
+ return info, nil
+ }
+
+ t := info.id.gobType()
+ switch typ := rt; typ.Kind() {
+ case reflect.Array:
+ info.wire = &wireType{ArrayT: t.(*arrayType)}
+ case reflect.Map:
+ info.wire = &wireType{MapT: t.(*mapType)}
+ case reflect.Slice:
+ // []byte == []uint8 is a special case handled separately
+ if typ.Elem().Kind() != reflect.Uint8 {
+ info.wire = &wireType{SliceT: t.(*sliceType)}
+ }
+ case reflect.Struct:
+ info.wire = &wireType{StructT: t.(*structType)}
+ }
+ typeInfoMap[rt] = info
+ return info, nil
+}
+
+// Called only when a panic is acceptable and unexpected.
+func mustGetTypeInfo(rt reflect.Type) *typeInfo {
+ t, err := getTypeInfo(userType(rt))
+ if err != nil {
+ panic("getTypeInfo: " + err.String())
+ }
+ return t
+}
+
+// GobEncoder is the interface describing data that provides its own
+// representation for encoding values for transmission to a GobDecoder.
+// A type that implements GobEncoder and GobDecoder has complete
+// control over the representation of its data and may therefore
+// contain things such as private fields, channels, and functions,
+// which are not usually transmissable in gob streams.
+//
+// Note: Since gobs can be stored permanently, It is good design
+// to guarantee the encoding used by a GobEncoder is stable as the
+// software evolves. For instance, it might make sense for GobEncode
+// to include a version number in the encoding.
+type GobEncoder interface {
+ // GobEncode returns a byte slice representing the encoding of the
+ // receiver for transmission to a GobDecoder, usually of the same
+ // concrete type.
+ GobEncode() ([]byte, os.Error)
+}
+
+// GobDecoder is the interface describing data that provides its own
+// routine for decoding transmitted values sent by a GobEncoder.
+type GobDecoder interface {
+ // GobDecode overwrites the receiver, which must be a pointer,
+ // with the value represented by the byte slice, which was written
+ // by GobEncode, usually for the same concrete type.
+ GobDecode([]byte) os.Error
+}
+
+var (
+ nameToConcreteType = make(map[string]reflect.Type)
+ concreteTypeToName = make(map[reflect.Type]string)
+)
+
+// RegisterName is like Register but uses the provided name rather than the
+// type's default.
+func RegisterName(name string, value interface{}) {
+ if name == "" {
+ // reserved for nil
+ panic("attempt to register empty name")
+ }
+ base := userType(reflect.Typeof(value)).base
+ // Check for incompatible duplicates.
+ if t, ok := nameToConcreteType[name]; ok && t != base {
+ panic("gob: registering duplicate types for " + name)
+ }
+ if n, ok := concreteTypeToName[base]; ok && n != name {
+ panic("gob: registering duplicate names for " + base.String())
+ }
+ // Store the name and type provided by the user....
+ nameToConcreteType[name] = reflect.Typeof(value)
+ // but the flattened type in the type table, since that's what decode needs.
+ concreteTypeToName[base] = name
+}
+
+// Register records a type, identified by a value for that type, under its
+// internal type name. That name will identify the concrete type of a value
+// sent or received as an interface variable. Only types that will be
+// transferred as implementations of interface values need to be registered.
+// Expecting to be used only during initialization, it panics if the mapping
+// between types and names is not a bijection.
+func Register(value interface{}) {
+ // Default to printed representation for unnamed types
+ rt := reflect.Typeof(value)
+ name := rt.String()
+
+ // But for named types (or pointers to them), qualify with import path.
+ // Dereference one pointer looking for a named type.
+ star := ""
+ if rt.Name() == "" {
+ if pt := rt; pt.Kind() == reflect.Ptr {
+ star = "*"
+ rt = pt
+ }
+ }
+ if rt.Name() != "" {
+ if rt.PkgPath() == "" {
+ name = star + rt.Name()
+ } else {
+ name = star + rt.PkgPath() + "." + rt.Name()
+ }
+ }
+
+ RegisterName(name, value)
+}
+
+func registerBasics() {
+ Register(int(0))
+ Register(int8(0))
+ Register(int16(0))
+ Register(int32(0))
+ Register(int64(0))
+ Register(uint(0))
+ Register(uint8(0))
+ Register(uint16(0))
+ Register(uint32(0))
+ Register(uint64(0))
+ Register(float32(0))
+ Register(float64(0))
+ Register(complex64(0i))
+ Register(complex128(0i))
+ Register(false)
+ Register("")
+ Register([]byte(nil))
+}