]> Cypherpunks repositories - gostls13.git/commitdiff
gofix: new command for updating code to new release
authorRuss Cox <rsc@golang.org>
Tue, 15 Mar 2011 18:15:41 +0000 (14:15 -0400)
committerRuss Cox <rsc@golang.org>
Tue, 15 Mar 2011 18:15:41 +0000 (14:15 -0400)
R=bradfitzgo, dsymonds, r, gri, adg
CC=golang-dev
https://golang.org/cl/4282044

src/cmd/Makefile
src/cmd/gofix/Makefile [new file with mode: 0644]
src/cmd/gofix/doc.go [new file with mode: 0644]
src/cmd/gofix/fix.go [new file with mode: 0644]
src/cmd/gofix/httpserver.go [new file with mode: 0644]
src/cmd/gofix/httpserver_test.go [new file with mode: 0644]
src/cmd/gofix/main.go [new file with mode: 0644]
src/cmd/gofix/main_test.go [new file with mode: 0644]
src/pkg/Makefile

index 779bd44c7932a14a1b7883b6d0bd6de18121bb25..fdb33f0702725f715e63b703099939ed9cc136d9 100644 (file)
@@ -41,6 +41,7 @@ CLEANDIRS=\
        cgo\
        ebnflint\
        godoc\
+       gofix\
        gofmt\
        goinstall\
        gotype\
diff --git a/src/cmd/gofix/Makefile b/src/cmd/gofix/Makefile
new file mode 100644 (file)
index 0000000..020a6a2
--- /dev/null
@@ -0,0 +1,16 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../Make.inc
+
+TARG=gofix
+GOFILES=\
+       fix.go\
+       main.go\
+       httpserver.go\
+
+include ../../Make.cmd
+
+test:
+       gotest
diff --git a/src/cmd/gofix/doc.go b/src/cmd/gofix/doc.go
new file mode 100644 (file)
index 0000000..e267d5d
--- /dev/null
@@ -0,0 +1,34 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Gofix finds Go programs that use old APIs and rewrites them to use
+newer ones.  After you update to a new Go release, gofix helps make
+the necessary changes to your programs.
+
+Usage:
+       gofix [-r name,...] [path ...]
+
+Without an explicit path, gofix reads standard input and writes the
+result to standard output.
+
+If the named path is a file, gofix rewrites the named files in place.
+If the named path is a directory, gofix rewrites all .go files in that
+directory tree.  When gofix rewrites a file, it prints a line to standard
+error giving the name of the file and the rewrite applied.
+
+The -r flag restricts the set of rewrites considered to those in the
+named list.  By default gofix considers all known rewrites.  Gofix's
+rewrites are idempotent, so that it is safe to apply gofix to updated
+or partially updated code even without using the -r flag.
+
+Gofix prints the full list of fixes it can apply in its help output;
+to see them, run godoc -?.
+
+Gofix does not make backup copies of the files that it edits.
+Instead, use a version control system's ``diff'' functionality to inspect
+the changes that gofix makes before committing them.
+
+*/
+package documentation
diff --git a/src/cmd/gofix/fix.go b/src/cmd/gofix/fix.go
new file mode 100644 (file)
index 0000000..eadddba
--- /dev/null
@@ -0,0 +1,297 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "fmt"
+       "go/ast"
+       "go/token"
+       "os"
+)
+
+type fix struct {
+       name string
+       f    func(*ast.File) bool
+       desc string
+}
+
+// main runs sort.Sort(fixes) after init process is done.
+type fixlist []fix
+
+func (f fixlist) Len() int           { return len(f) }
+func (f fixlist) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
+func (f fixlist) Less(i, j int) bool { return f[i].name < f[j].name }
+
+var fixes fixlist
+
+func register(f fix) {
+       fixes = append(fixes, f)
+}
+
+// rewrite walks the AST x, calling visit(y) for each node y in the tree but
+// also with a pointer to each ast.Expr, in a bottom-up traversal.
+func rewrite(x interface{}, visit func(interface{})) {
+       switch n := x.(type) {
+       case *ast.Expr:
+               rewrite(*n, visit)
+
+       // everything else just recurses
+       default:
+               panic(fmt.Errorf("unexpected type %T in walk", x, visit))
+
+       case nil:
+
+       // These are ordered and grouped to match ../../pkg/go/ast/ast.go
+       case *ast.Field:
+               rewrite(&n.Type, visit)
+       case *ast.FieldList:
+               for _, field := range n.List {
+                       rewrite(field, visit)
+               }
+       case *ast.BadExpr:
+       case *ast.Ident:
+       case *ast.Ellipsis:
+       case *ast.BasicLit:
+       case *ast.FuncLit:
+               rewrite(n.Type, visit)
+               rewrite(n.Body, visit)
+       case *ast.CompositeLit:
+               rewrite(&n.Type, visit)
+               rewrite(n.Elts, visit)
+       case *ast.ParenExpr:
+               rewrite(&n.X, visit)
+       case *ast.SelectorExpr:
+               rewrite(&n.X, visit)
+       case *ast.IndexExpr:
+               rewrite(&n.X, visit)
+               rewrite(&n.Index, visit)
+       case *ast.SliceExpr:
+               rewrite(&n.X, visit)
+               if n.Low != nil {
+                       rewrite(&n.Low, visit)
+               }
+               if n.High != nil {
+                       rewrite(&n.High, visit)
+               }
+       case *ast.TypeAssertExpr:
+               rewrite(&n.X, visit)
+               rewrite(&n.Type, visit)
+       case *ast.CallExpr:
+               rewrite(&n.Fun, visit)
+               rewrite(n.Args, visit)
+       case *ast.StarExpr:
+               rewrite(&n.X, visit)
+       case *ast.UnaryExpr:
+               rewrite(&n.X, visit)
+       case *ast.BinaryExpr:
+               rewrite(&n.X, visit)
+               rewrite(&n.Y, visit)
+       case *ast.KeyValueExpr:
+               rewrite(&n.Key, visit)
+               rewrite(&n.Value, visit)
+
+       case *ast.ArrayType:
+               rewrite(&n.Len, visit)
+               rewrite(&n.Elt, visit)
+       case *ast.StructType:
+               rewrite(n.Fields, visit)
+       case *ast.FuncType:
+               rewrite(n.Params, visit)
+               if n.Results != nil {
+                       rewrite(n.Results, visit)
+               }
+       case *ast.InterfaceType:
+               rewrite(n.Methods, visit)
+       case *ast.MapType:
+               rewrite(&n.Key, visit)
+               rewrite(&n.Value, visit)
+       case *ast.ChanType:
+               rewrite(&n.Value, visit)
+
+       case *ast.BadStmt:
+       case *ast.DeclStmt:
+               rewrite(n.Decl, visit)
+       case *ast.EmptyStmt:
+       case *ast.LabeledStmt:
+               rewrite(n.Stmt, visit)
+       case *ast.ExprStmt:
+               rewrite(&n.X, visit)
+       case *ast.SendStmt:
+               rewrite(&n.Chan, visit)
+               rewrite(&n.Value, visit)
+       case *ast.IncDecStmt:
+               rewrite(&n.X, visit)
+       case *ast.AssignStmt:
+               rewrite(n.Lhs, visit)
+               if len(n.Lhs) == 2 && len(n.Rhs) == 1 {
+                       rewrite(n.Rhs, visit)
+               } else {
+                       rewrite(n.Rhs, visit)
+               }
+       case *ast.GoStmt:
+               rewrite(n.Call, visit)
+       case *ast.DeferStmt:
+               rewrite(n.Call, visit)
+       case *ast.ReturnStmt:
+               rewrite(n.Results, visit)
+       case *ast.BranchStmt:
+       case *ast.BlockStmt:
+               rewrite(n.List, visit)
+       case *ast.IfStmt:
+               rewrite(n.Init, visit)
+               rewrite(&n.Cond, visit)
+               rewrite(n.Body, visit)
+               rewrite(n.Else, visit)
+       case *ast.CaseClause:
+               rewrite(n.Values, visit)
+               rewrite(n.Body, visit)
+       case *ast.SwitchStmt:
+               rewrite(n.Init, visit)
+               rewrite(&n.Tag, visit)
+               rewrite(n.Body, visit)
+       case *ast.TypeCaseClause:
+               rewrite(n.Types, visit)
+               rewrite(n.Body, visit)
+       case *ast.TypeSwitchStmt:
+               rewrite(n.Init, visit)
+               rewrite(n.Assign, visit)
+               rewrite(n.Body, visit)
+       case *ast.CommClause:
+               rewrite(n.Comm, visit)
+               rewrite(n.Body, visit)
+       case *ast.SelectStmt:
+               rewrite(n.Body, visit)
+       case *ast.ForStmt:
+               rewrite(n.Init, visit)
+               rewrite(&n.Cond, visit)
+               rewrite(n.Post, visit)
+               rewrite(n.Body, visit)
+       case *ast.RangeStmt:
+               rewrite(&n.Key, visit)
+               rewrite(&n.Value, visit)
+               rewrite(&n.X, visit)
+               rewrite(n.Body, visit)
+
+       case *ast.ImportSpec:
+       case *ast.ValueSpec:
+               rewrite(&n.Type, visit)
+               rewrite(n.Values, visit)
+       case *ast.TypeSpec:
+               rewrite(&n.Type, visit)
+
+       case *ast.BadDecl:
+       case *ast.GenDecl:
+               rewrite(n.Specs, visit)
+       case *ast.FuncDecl:
+               if n.Recv != nil {
+                       rewrite(n.Recv, visit)
+               }
+               rewrite(n.Type, visit)
+               if n.Body != nil {
+                       rewrite(n.Body, visit)
+               }
+
+       case *ast.File:
+               rewrite(n.Decls, visit)
+
+       case *ast.Package:
+               for _, file := range n.Files {
+                       rewrite(file, visit)
+               }
+
+       case []ast.Decl:
+               for _, d := range n {
+                       rewrite(d, visit)
+               }
+       case []ast.Expr:
+               for i := range n {
+                       rewrite(&n[i], visit)
+               }
+       case []ast.Stmt:
+               for _, s := range n {
+                       rewrite(s, visit)
+               }
+       case []ast.Spec:
+               for _, s := range n {
+                       rewrite(s, visit)
+               }
+       }
+       visit(x)
+}
+
+func imports(f *ast.File, path string) bool {
+       for _, decl := range f.Decls {
+               d, ok := decl.(*ast.GenDecl)
+               if !ok {
+                       continue
+               }
+               for _, spec := range d.Specs {
+                       s, ok := spec.(*ast.ImportSpec)
+                       if !ok {
+                               continue
+                       }
+                       if string(s.Path.Value) == `"`+path+`"` {
+                               return true
+                       }
+               }
+       }
+       return false
+}
+
+func isPkgDot(t ast.Expr, pkg, name string) bool {
+       sel, ok := t.(*ast.SelectorExpr)
+       if !ok {
+               return false
+       }
+       return isName(sel.X, pkg) && sel.Sel.String() == name
+}
+
+func isPtrPkgDot(t ast.Expr, pkg, name string) bool {
+       ptr, ok := t.(*ast.StarExpr)
+       if !ok {
+               return false
+       }
+       return isPkgDot(ptr.X, pkg, name)
+}
+
+func isName(n ast.Expr, name string) bool {
+       id, ok := n.(*ast.Ident)
+       if !ok {
+               return false
+       }
+       return id.String() == name
+}
+
+func refersTo(n ast.Node, x *ast.Ident) bool {
+       id, ok := n.(*ast.Ident)
+       if !ok {
+               return false
+       }
+       return id.String() == x.String()
+}
+
+func isBlank(n ast.Expr) bool {
+       return isName(n, "_")
+}
+
+func isEmptyString(n ast.Expr) bool {
+       lit, ok := n.(*ast.BasicLit)
+       if !ok {
+               return false
+       }
+       if lit.Kind != token.STRING {
+               return false
+       }
+       s := string(lit.Value)
+       return s == `""` || s == "``"
+}
+
+func warn(pos token.Pos, msg string, args ...interface{}) {
+       s := ""
+       if pos.IsValid() {
+               s = fmt.Sprintf("%s: ", fset.Position(pos).String())
+       }
+       fmt.Fprintf(os.Stderr, "%s"+msg+"\n", append([]interface{}{s}, args...))
+}
diff --git a/src/cmd/gofix/httpserver.go b/src/cmd/gofix/httpserver.go
new file mode 100644 (file)
index 0000000..8899653
--- /dev/null
@@ -0,0 +1,125 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "go/ast"
+       "go/token"
+)
+
+var httpserverFix = fix{
+       "httpserver",
+       httpserver,
+`Adapt http server methods and functions to changes
+made to the http ResponseWriter interface.
+
+http://codereview.appspot.com/4245064  Hijacker
+http://codereview.appspot.com/4239076  Header
+http://codereview.appspot.com/4239077  Flusher
+http://codereview.appspot.com/4248075  RemoteAddr, UsingTLS
+`,
+}
+
+func init() {
+       register(httpserverFix)
+}
+
+func httpserver(f *ast.File) bool {
+       if !imports(f, "http") {
+               return false
+       }
+
+       fixed := false
+       for _, decl := range f.Decls {
+               fn, ok := decl.(*ast.FuncDecl)
+               if !ok {
+                       continue
+               }
+               w, req, ok := isServeHTTP(fn)
+               if !ok {
+                       continue
+               }
+               rewrite(fn.Body, func(n interface{}) {
+                       // Want to replace expression sometimes,
+                       // so record pointer to it for updating below.
+                       ptr, ok := n.(*ast.Expr)
+                       if ok {
+                               n = *ptr
+                       }
+
+                       // Look for w.UsingTLS() and w.Remoteaddr().
+                       call, ok := n.(*ast.CallExpr)
+                       if !ok || len(call.Args) != 0 {
+                               return
+                       }
+                       sel, ok := call.Fun.(*ast.SelectorExpr)
+                       if !ok {
+                               return
+                       }
+                       if !refersTo(sel.X, w) {
+                               return
+                       }
+                       switch sel.Sel.String() {
+                       case "Hijack":
+                               // replace w with w.(http.Hijacker)
+                               sel.X = &ast.TypeAssertExpr{
+                                       X:    sel.X,
+                                       Type: ast.NewIdent("http.Hijacker"),
+                               }
+                               fixed = true
+                       case "Flush":
+                               // replace w with w.(http.Flusher)
+                               sel.X = &ast.TypeAssertExpr{
+                                       X:    sel.X,
+                                       Type: ast.NewIdent("http.Flusher"),
+                               }
+                               fixed = true
+                       case "UsingTLS":
+                               if ptr == nil {
+                                       // can only replace expression if we have pointer to it
+                                       break
+                               }
+                               // replace with req.TLS != nil
+                               *ptr = &ast.BinaryExpr{
+                                       X: &ast.SelectorExpr{
+                                               X:   ast.NewIdent(req.String()),
+                                               Sel: ast.NewIdent("TLS"),
+                                       },
+                                       Op: token.NEQ,
+                                       Y:  ast.NewIdent("nil"),
+                               }
+                               fixed = true
+                       case "RemoteAddr":
+                               if ptr == nil {
+                                       // can only replace expression if we have pointer to it
+                                       break
+                               }
+                               // replace with req.RemoteAddr
+                               *ptr = &ast.SelectorExpr{
+                                       X:   ast.NewIdent(req.String()),
+                                       Sel: ast.NewIdent("RemoteAddr"),
+                               }
+                               fixed = true
+                       }
+               })
+       }
+       return fixed
+}
+
+func isServeHTTP(fn *ast.FuncDecl) (w, req *ast.Ident, ok bool) {
+       for _, field := range fn.Type.Params.List {
+               if isPkgDot(field.Type, "http", "ResponseWriter") {
+                       w = field.Names[0]
+                       continue
+               }
+               if isPtrPkgDot(field.Type, "http", "Request") {
+                       req = field.Names[0]
+                       continue
+               }
+       }
+
+       ok = w != nil && req != nil
+       return
+}
diff --git a/src/cmd/gofix/httpserver_test.go b/src/cmd/gofix/httpserver_test.go
new file mode 100644 (file)
index 0000000..7e79056
--- /dev/null
@@ -0,0 +1,46 @@
+package main
+
+func init() {
+       addTestCases(httpserverTests)
+}
+
+var httpserverTests = []testCase{
+       {
+               Name: "httpserver.0",
+               Fn:   httpserver,
+               In: `package main
+
+import "http"
+
+func f(xyz http.ResponseWriter, abc *http.Request, b string) {
+       xyz.Hijack()
+       xyz.Flush()
+       go xyz.Hijack()
+       defer xyz.Flush()
+       _ = xyz.UsingTLS()
+       _ = true == xyz.UsingTLS()
+       _ = xyz.RemoteAddr()
+       _ = xyz.RemoteAddr() == "hello"
+       if xyz.UsingTLS() {
+       }
+}
+`,
+               Out: `package main
+
+import "http"
+
+func f(xyz http.ResponseWriter, abc *http.Request, b string) {
+       xyz.(http.Hijacker).Hijack()
+       xyz.(http.Flusher).Flush()
+       go xyz.(http.Hijacker).Hijack()
+       defer xyz.(http.Flusher).Flush()
+       _ = abc.TLS != nil
+       _ = true == (abc.TLS != nil)
+       _ = abc.RemoteAddr
+       _ = abc.RemoteAddr == "hello"
+       if abc.TLS != nil {
+       }
+}
+`,
+       },
+}
diff --git a/src/cmd/gofix/main.go b/src/cmd/gofix/main.go
new file mode 100644 (file)
index 0000000..40c86e8
--- /dev/null
@@ -0,0 +1,179 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "bytes"
+       "flag"
+       "fmt"
+       "go/parser"
+       "go/printer"
+       "go/scanner"
+       "go/token"
+       "io/ioutil"
+       "os"
+       "path/filepath"
+       "sort"
+       "strings"
+)
+
+var (
+       fset     = token.NewFileSet()
+       exitCode = 0
+)
+
+var allowedRewrites = flag.String("r", "",
+       "restrict the rewrites to this comma-separated list")
+
+var allowed map[string]bool
+
+func usage() {
+       fmt.Fprintf(os.Stderr, "usage: gofix [-r fixname,...] [path ...]\n")
+       flag.PrintDefaults()
+       fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
+       for _, f := range fixes {
+               fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
+               desc := strings.TrimSpace(f.desc)
+               desc = strings.Replace(desc, "\n", "\n\t", -1)
+               fmt.Fprintf(os.Stderr, "\t%s\n", desc)
+       }
+       os.Exit(2)
+}
+
+func main() {
+       sort.Sort(fixes)
+
+       flag.Usage = usage
+       flag.Parse()
+
+       if *allowedRewrites != "" {
+               allowed = make(map[string]bool)
+               for _, f := range strings.Split(*allowedRewrites, ",", -1) {
+                       allowed[f] = true
+               }
+       }
+
+       if flag.NArg() == 0 {
+               if err := processFile("standard input", true); err != nil {
+                       report(err)
+               }
+               os.Exit(exitCode)
+       }
+
+       for i := 0; i < flag.NArg(); i++ {
+               path := flag.Arg(i)
+               switch dir, err := os.Stat(path); {
+               case err != nil:
+                       report(err)
+               case dir.IsRegular():
+                       if err := processFile(path, false); err != nil {
+                               report(err)
+                       }
+               case dir.IsDirectory():
+                       walkDir(path)
+               }
+       }
+
+       os.Exit(exitCode)
+}
+
+const (
+       tabWidth    = 8
+       parserMode  = parser.ParseComments
+       printerMode = printer.TabIndent
+)
+
+
+func processFile(filename string, useStdin bool) os.Error {
+       var f *os.File
+       var err os.Error
+
+       if useStdin {
+               f = os.Stdin
+       } else {
+               f, err = os.Open(filename, os.O_RDONLY, 0)
+               if err != nil {
+                       return err
+               }
+               defer f.Close()
+       }
+
+       src, err := ioutil.ReadAll(f)
+       if err != nil {
+               return err
+       }
+
+       file, err := parser.ParseFile(fset, filename, src, parserMode)
+       if err != nil {
+               return err
+       }
+
+       fixed := false
+       var buf bytes.Buffer
+       for _, fix := range fixes {
+               if allowed != nil && !allowed[fix.desc] {
+                       continue
+               }
+               if fix.f(file) {
+                       fixed = true
+                       fmt.Fprintf(&buf, " %s", fix.name)
+               }
+       }
+       if !fixed {
+               return nil
+       }
+       fmt.Fprintf(os.Stderr, "%s: %s\n", filename, buf.String()[1:])
+
+       buf.Reset()
+       _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+       if err != nil {
+               return err
+       }
+
+       if useStdin {
+               os.Stdout.Write(buf.Bytes())
+               return nil
+       }
+
+       return ioutil.WriteFile(f.Name(), buf.Bytes(), 0)
+}
+
+func report(err os.Error) {
+       scanner.PrintError(os.Stderr, err)
+       exitCode = 2
+}
+
+func walkDir(path string) {
+       v := make(fileVisitor)
+       go func() {
+               filepath.Walk(path, v, v)
+               close(v)
+       }()
+       for err := range v {
+               if err != nil {
+                       report(err)
+               }
+       }
+}
+
+type fileVisitor chan os.Error
+
+func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
+       return true
+}
+
+func (v fileVisitor) VisitFile(path string, f *os.FileInfo) {
+       if isGoFile(f) {
+               v <- nil // synchronize error handler
+               if err := processFile(path, false); err != nil {
+                       v <- err
+               }
+       }
+}
+
+func isGoFile(f *os.FileInfo) bool {
+       // ignore non-Go files
+       return f.IsRegular() && !strings.HasPrefix(f.Name, ".") && strings.HasSuffix(f.Name, ".go")
+}
diff --git a/src/cmd/gofix/main_test.go b/src/cmd/gofix/main_test.go
new file mode 100644 (file)
index 0000000..597bff2
--- /dev/null
@@ -0,0 +1,102 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "bytes"
+       "go/ast"
+       "go/parser"
+       "go/printer"
+       "testing"
+)
+
+type testCase struct {
+       Name string
+       Fn   func(*ast.File) bool
+       In   string
+       Out  string
+}
+
+var testCases []testCase
+
+func addTestCases(t []testCase) {
+       testCases = append(testCases, t...)
+}
+
+func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) {
+       file, err := parser.ParseFile(fset, desc, in, parserMode)
+       if err != nil {
+               t.Errorf("%s: parsing: %v", desc, err)
+               return
+       }
+
+       var buf bytes.Buffer
+       buf.Reset()
+       _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+       if err != nil {
+               t.Errorf("%s: printing: %v", desc, err)
+               return
+       }
+       if s := buf.String(); in != s {
+               t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
+                       desc, desc, in, desc, s)
+               return
+       }
+
+       if fn == nil {
+               for _, fix := range fixes {
+                       if fix.f(file) {
+                               fixed = true
+                       }
+               }
+       } else {
+               fixed = fn(file)
+       }
+
+       buf.Reset()
+       _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+       if err != nil {
+               t.Errorf("%s: printing: %v", desc, err)
+               return
+       }
+
+       return buf.String(), fixed, true
+}
+
+func TestRewrite(t *testing.T) {
+       for _, tt := range testCases {
+               // Apply fix: should get tt.Out.
+               out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In)
+               if !ok {
+                       continue
+               }
+
+               if out != tt.Out {
+                       t.Errorf("%s: incorrect output.\n--- have\n%s\n--- want\n%s", tt.Name, out, tt.Out)
+                       continue
+               }
+
+               if changed := out != tt.In; changed != fixed {
+                       t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed)
+                       continue
+               }
+
+               // Should not change if run again.
+               out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out)
+               if !ok {
+                       continue
+               }
+
+               if fixed2 {
+                       t.Errorf("%s: applied fixes during second round", tt.Name)
+                       continue
+               }
+
+               if out2 != out {
+                       t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
+                               tt.Name, out, out2)
+               }
+       }
+}
index 3edb1e60bd0e9846a7ab944dd0ae84a6e4ab0543..24b304346de177c2be4defe763e04b398b28c920 100644 (file)
@@ -155,6 +155,7 @@ DIRS=\
        ../cmd/cgo\
        ../cmd/ebnflint\
        ../cmd/godoc\
+       ../cmd/gofix\
        ../cmd/gofmt\
        ../cmd/gotype\
        ../cmd/goinstall\