]> Cypherpunks repositories - gostls13.git/commitdiff
go/format: Package format implements standard formatting of Go source.
authorRobert Griesemer <gri@golang.org>
Tue, 27 Nov 2012 18:29:49 +0000 (10:29 -0800)
committerRobert Griesemer <gri@golang.org>
Tue, 27 Nov 2012 18:29:49 +0000 (10:29 -0800)
Package format is a utility package that takes care of
parsing, sorting of imports, and formatting of .go source
using the canonical gofmt formatting parameters.

Use go/format in various clients instead of the lower-level components.

R=r, bradfitz, dave, rogpeppe, rsc
CC=golang-dev
https://golang.org/cl/6852075

src/cmd/fix/main.go
src/cmd/godoc/godoc.go
src/cmd/godoc/play.go
src/cmd/godoc/template.go
src/pkg/go/format/format.go [new file with mode: 0644]
src/pkg/go/format/format_test.go [new file with mode: 0644]

index b151408d74c6907766a579bc0ca2cd508f85eda1..dc10d6beb526d892d3dec26c20063b3732208462 100644 (file)
@@ -9,8 +9,8 @@ import (
        "flag"
        "fmt"
        "go/ast"
+       "go/format"
        "go/parser"
-       "go/printer"
        "go/scanner"
        "go/token"
        "io/ioutil"
@@ -97,23 +97,11 @@ func main() {
        os.Exit(exitCode)
 }
 
-const (
-       tabWidth    = 8
-       parserMode  = parser.ParseComments
-       printerMode = printer.TabIndent | printer.UseSpaces
-)
-
-var printConfig = &printer.Config{
-       Mode:     printerMode,
-       Tabwidth: tabWidth,
-}
+const parserMode = parser.ParseComments
 
 func gofmtFile(f *ast.File) ([]byte, error) {
        var buf bytes.Buffer
-
-       ast.SortImports(fset, f)
-       err := printConfig.Fprint(&buf, fset, f)
-       if err != nil {
+       if err := format.Node(&buf, fset, f); err != nil {
                return nil, err
        }
        return buf.Bytes(), nil
@@ -211,8 +199,7 @@ var gofmtBuf bytes.Buffer
 
 func gofmt(n interface{}) string {
        gofmtBuf.Reset()
-       err := printConfig.Fprint(&gofmtBuf, fset, n)
-       if err != nil {
+       if err := format.Node(&gofmtBuf, fset, n); err != nil {
                return "<" + err.Error() + ">"
        }
        return gofmtBuf.String()
index 57ef9f37782ef391afa49f8a1ac6b415425d0b89..9ac38c746e851e7583f36b1601737de7612f8e25 100644 (file)
@@ -12,6 +12,7 @@ import (
        "go/ast"
        "go/build"
        "go/doc"
+       "go/format"
        "go/printer"
        "go/token"
        "io"
@@ -356,10 +357,8 @@ func example_htmlFunc(funcName string, examples []*doc.Example, fset *token.File
                // (use tabs, no comment highlight, etc).
                play := ""
                if eg.Play != nil && *showPlayground {
-                       ast.SortImports(fset, eg.Play)
                        var buf bytes.Buffer
-                       err := (&printer.Config{Mode: printer.TabIndent, Tabwidth: 8}).Fprint(&buf, fset, eg.Play)
-                       if err != nil {
+                       if err := format.Node(&buf, fset, eg.Play); err != nil {
                                log.Print(err)
                        } else {
                                play = buf.String()
index 7033169c83fb4230c368e6fbe4ce8b866de74b01..47a11f6c0bf8ef0a86eb3b57a544296c235bfb8e 100644 (file)
@@ -7,13 +7,9 @@
 package main
 
 import (
-       "bytes"
        "encoding/json"
        "fmt"
-       "go/ast"
-       "go/parser"
-       "go/printer"
-       "go/token"
+       "go/format"
        "net/http"
 )
 
@@ -40,36 +36,15 @@ type fmtResponse struct {
 // standard gofmt formatting, and writes a fmtResponse as a JSON object.
 func fmtHandler(w http.ResponseWriter, r *http.Request) {
        resp := new(fmtResponse)
-       body, err := gofmt(r.FormValue("body"))
+       body, err := format.Source([]byte(r.FormValue("body")))
        if err != nil {
                resp.Error = err.Error()
        } else {
-               resp.Body = body
+               resp.Body = string(body)
        }
        json.NewEncoder(w).Encode(resp)
 }
 
-// gofmt takes a Go program, formats it using the standard Go formatting
-// rules, and returns it or an error.
-func gofmt(body string) (string, error) {
-       fset := token.NewFileSet()
-       f, err := parser.ParseFile(fset, "prog.go", body, parser.ParseComments)
-       if err != nil {
-               return "", err
-       }
-       ast.SortImports(fset, f)
-       var buf bytes.Buffer
-       config := printer.Config{
-               Mode:     printer.UseSpaces | printer.TabIndent,
-               Tabwidth: 8,
-       }
-       err = config.Fprint(&buf, fset, f)
-       if err != nil {
-               return "", err
-       }
-       return buf.String(), nil
-}
-
 // disabledHandler serves a 501 "Not Implemented" response.
 func disabledHandler(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusNotImplemented)
index c96bf5bc4e1b7c68b27420c1259300c58ca902ff..7b9b9cfeb02ea31b004586be36299a88324f112d 100644 (file)
@@ -57,8 +57,8 @@ func contents(name string) string {
        return string(file)
 }
 
-// format returns a textual representation of the arg, formatted according to its nature.
-func format(arg interface{}) string {
+// stringFor returns a textual representation of the arg, formatted according to its nature.
+func stringFor(arg interface{}) string {
        switch arg := arg.(type) {
        case int:
                return fmt.Sprintf("%d", arg)
@@ -87,10 +87,10 @@ func code(file string, arg ...interface{}) (s string, err error) {
                // text is already whole file.
                command = fmt.Sprintf("code %q", file)
        case 1:
-               command = fmt.Sprintf("code %q %s", file, format(arg[0]))
+               command = fmt.Sprintf("code %q %s", file, stringFor(arg[0]))
                text = oneLine(file, text, arg[0])
        case 2:
-               command = fmt.Sprintf("code %q %s %s", file, format(arg[0]), format(arg[1]))
+               command = fmt.Sprintf("code %q %s %s", file, stringFor(arg[0]), stringFor(arg[1]))
                text = multipleLines(file, text, arg[0], arg[1])
        default:
                return "", fmt.Errorf("incorrect code invocation: code %q %q", file, arg)
diff --git a/src/pkg/go/format/format.go b/src/pkg/go/format/format.go
new file mode 100644 (file)
index 0000000..286296e
--- /dev/null
@@ -0,0 +1,200 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package format implements standard formatting of Go source.
+package format
+
+import (
+       "bytes"
+       "fmt"
+       "go/ast"
+       "go/parser"
+       "go/printer"
+       "go/token"
+       "io"
+       "strings"
+)
+
+var config = printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
+
+// Node formats node in canonical gofmt style and writes the result to dst.
+//
+// The node type must be *ast.File, *printer.CommentedNode, []ast.Decl,
+// []ast.Stmt, or assignment-compatible to ast.Expr, ast.Decl, ast.Spec,
+// or ast.Stmt. Node does not modify node. Imports are not sorted for
+// nodes representing partial source files (i.e., if the node is not an
+// *ast.File or a *printer.CommentedNode not wrapping an *ast.File).
+//
+// The function may return early (before the entire result is written)
+// and return a formatting error, for instance due to an incorrect AST.
+//
+func Node(dst io.Writer, fset *token.FileSet, node interface{}) error {
+       // Determine if we have a complete source file (file != nil).
+       var file *ast.File
+       var cnode *printer.CommentedNode
+       switch n := node.(type) {
+       case *ast.File:
+               file = n
+       case *printer.CommentedNode:
+               if f, ok := n.Node.(*ast.File); ok {
+                       file = f
+                       cnode = n
+               }
+       }
+
+       // Sort imports if necessary.
+       if file != nil && hasUnsortedImports(file) {
+               // Make a copy of the AST because ast.SortImports is destructive.
+               // TODO(gri) Do this more efficently.
+               var buf bytes.Buffer
+               err := config.Fprint(&buf, fset, file)
+               if err != nil {
+                       return err
+               }
+               file, err = parser.ParseFile(fset, "", buf.Bytes(), parser.ParseComments)
+               if err != nil {
+                       // We should never get here. If we do, provide good diagnostic.
+                       return fmt.Errorf("format.Node internal error (%s)", err)
+               }
+               ast.SortImports(fset, file)
+
+               // Use new file with sorted imports.
+               node = file
+               if cnode != nil {
+                       node = &printer.CommentedNode{Node: file, Comments: cnode.Comments}
+               }
+       }
+
+       return config.Fprint(dst, fset, node)
+}
+
+// Source formats src in canonical gofmt style and writes the result to dst
+// or returns an I/O or syntax error. src is expected to be a syntactically
+// correct Go source file, or a list of Go declarations or statements.
+//
+// If src is a partial source file, the leading and trailing space of src
+// is applied to the result (such that it has the same leading and trailing
+// space as src), and the formatted src is indented by the same amount as
+// the first line of src containing code. Imports are not sorted for partial
+// source files.
+//
+func Source(src []byte) ([]byte, error) {
+       fset := token.NewFileSet()
+       node, err := parse(fset, src)
+       if err != nil {
+               return nil, err
+       }
+
+       var buf bytes.Buffer
+       if file, ok := node.(*ast.File); ok {
+               // Complete source file.
+               ast.SortImports(fset, file)
+               err := config.Fprint(&buf, fset, file)
+               if err != nil {
+                       return nil, err
+               }
+
+       } else {
+               // Partial source file.
+               // Determine and prepend leading space.
+               i, j := 0, 0
+               for j < len(src) && isSpace(src[j]) {
+                       if src[j] == '\n' {
+                               i = j + 1 // index of last line in leading space
+                       }
+                       j++
+               }
+               buf.Write(src[:i])
+
+               // Determine indentation of first code line.
+               // Spaces are ignored unless there are no tabs,
+               // in which case spaces count as one tab.
+               indent := 0
+               hasSpace := false
+               for _, b := range src[i:j] {
+                       switch b {
+                       case ' ':
+                               hasSpace = true
+                       case '\t':
+                               indent++
+                       }
+               }
+               if indent == 0 && hasSpace {
+                       indent = 1
+               }
+
+               // Format the source.
+               cfg := config
+               cfg.Indent = indent
+               err := cfg.Fprint(&buf, fset, node)
+               if err != nil {
+                       return nil, err
+               }
+
+               // Determine and append trailing space.
+               i = len(src)
+               for i > 0 && isSpace(src[i-1]) {
+                       i--
+               }
+               buf.Write(src[i:])
+       }
+
+       return buf.Bytes(), nil
+}
+
+func hasUnsortedImports(file *ast.File) bool {
+       for _, d := range file.Decls {
+               d, ok := d.(*ast.GenDecl)
+               if !ok || d.Tok != token.IMPORT {
+                       // Not an import declaration, so we're done.
+                       // Imports are always first.
+                       return false
+               }
+               if d.Lparen.IsValid() {
+                       // For now assume all grouped imports are unsorted.
+                       // TODO(gri) Should check if they are sorted already.
+                       return true
+               }
+               // Ungrouped imports are sorted by default.
+       }
+       return false
+}
+
+func isSpace(b byte) bool {
+       return b == ' ' || b == '\t' || b == '\n' || b == '\r'
+}
+
+func parse(fset *token.FileSet, src []byte) (interface{}, error) {
+       // Try as a complete source file.
+       file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
+       if err == nil {
+               return file, nil
+       }
+       // If the source is missing a package clause, try as a source fragment; otherwise fail.
+       if !strings.Contains(err.Error(), "expected 'package'") {
+               return nil, err
+       }
+
+       // Try as a declaration list by prepending a package clause in front of src.
+       // Use ';' not '\n' to keep line numbers intact.
+       psrc := append([]byte("package p;"), src...)
+       file, err = parser.ParseFile(fset, "", psrc, parser.ParseComments)
+       if err == nil {
+               return file.Decls, nil
+       }
+       // If the source is missing a declaration, try as a statement list; otherwise fail.
+       if !strings.Contains(err.Error(), "expected declaration") {
+               return nil, err
+       }
+
+       // Try as statement list by wrapping a function around src.
+       fsrc := append(append([]byte("package p; func _() {"), src...), '}')
+       file, err = parser.ParseFile(fset, "", fsrc, parser.ParseComments)
+       if err == nil {
+               return file.Decls[0].(*ast.FuncDecl).Body.List, nil
+       }
+
+       // Failed, and out of options.
+       return nil, err
+}
diff --git a/src/pkg/go/format/format_test.go b/src/pkg/go/format/format_test.go
new file mode 100644 (file)
index 0000000..7d7940b
--- /dev/null
@@ -0,0 +1,125 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package format
+
+import (
+       "bytes"
+       "go/parser"
+       "go/token"
+       "io/ioutil"
+       "strings"
+       "testing"
+)
+
+const testfile = "format_test.go"
+
+func diff(t *testing.T, dst, src []byte) {
+       line := 1
+       offs := 0 // line offset
+       for i := 0; i < len(dst) && i < len(src); i++ {
+               d := dst[i]
+               s := src[i]
+               if d != s {
+                       t.Errorf("dst:%d: %s\n", line, dst[offs:i+1])
+                       t.Errorf("src:%d: %s\n", line, src[offs:i+1])
+                       return
+               }
+               if s == '\n' {
+                       line++
+                       offs = i + 1
+               }
+       }
+       if len(dst) != len(src) {
+               t.Errorf("len(dst) = %d, len(src) = %d\nsrc = %q", len(dst), len(src), src)
+       }
+}
+
+func TestNode(t *testing.T) {
+       src, err := ioutil.ReadFile(testfile)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       fset := token.NewFileSet()
+       file, err := parser.ParseFile(fset, testfile, src, parser.ParseComments)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       var buf bytes.Buffer
+
+       if err = Node(&buf, fset, file); err != nil {
+               t.Fatal("Node failed:", err)
+       }
+
+       diff(t, buf.Bytes(), src)
+}
+
+func TestSource(t *testing.T) {
+       src, err := ioutil.ReadFile(testfile)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       res, err := Source(src)
+       if err != nil {
+               t.Fatal("Source failed:", err)
+       }
+
+       diff(t, res, src)
+}
+
+// Test cases that are expected to fail are marked by the prefix "ERROR".
+var tests = []string{
+       // declaration lists
+       `import "go/format"`,
+       "var x int",
+       "var x int\n\ntype T struct{}",
+
+       // statement lists
+       "x := 0",
+       "f(a, b, c)\nvar x int = f(1, 2, 3)",
+
+       // indentation, leading and trailing space
+       "\tx := 0\n\tgo f()",
+       "\tx := 0\n\tgo f()\n\n\n",
+       "\n\t\t\n\n\tx := 0\n\tgo f()\n\n\n",
+       "\n\t\t\n\n\t\t\tx := 0\n\t\t\tgo f()\n\n\n",
+       "\n\t\t\n\n\t\t\tx := 0\n\t\t\tconst s = `\nfoo\n`\n\n\n", // no indentation inside raw strings
+
+       // erroneous programs
+       "ERRORvar x",
+       "ERROR1 + 2 +",
+       "ERRORx :=  0",
+}
+
+func String(s string) (string, error) {
+       res, err := Source([]byte(s))
+       if err != nil {
+               return "", err
+       }
+       return string(res), nil
+}
+
+func TestPartial(t *testing.T) {
+       for _, src := range tests {
+               if strings.HasPrefix(src, "ERROR") {
+                       // test expected to fail
+                       src = src[5:] // remove ERROR prefix
+                       res, err := String(src)
+                       if err == nil && res == src {
+                               t.Errorf("formatting succeeded but was expected to fail:\n%q", src)
+                       }
+               } else {
+                       // test expected to succeed
+                       res, err := String(src)
+                       if err != nil {
+                               t.Errorf("formatting failed (%s):\n%q", err, src)
+                       } else if res != src {
+                               t.Errorf("formatting incorrect:\nsource: %q\nresult: %q", src, res)
+                       }
+               }
+       }
+}