]> Cypherpunks repositories - gostls13.git/commitdiff
gofmt: add -r flag to rewrite source code according to pattern
authorRuss Cox <rsc@golang.org>
Fri, 20 Nov 2009 23:09:54 +0000 (15:09 -0800)
committerRuss Cox <rsc@golang.org>
Fri, 20 Nov 2009 23:09:54 +0000 (15:09 -0800)
a little slow, but usable (speed unchanged when not using -r)

tweak go/printer to handle nodes without line numbers
more gracefully in a couple cases.

R=gri
https://golang.org/cl/156103

src/cmd/gofmt/Makefile
src/cmd/gofmt/gofmt.go
src/cmd/gofmt/rewrite.go [new file with mode: 0644]
src/pkg/go/printer/nodes.go

index a93b8c37266f1e9820409e111d8ba86e3b89e235..dbc134f88ec51901f4805eb5e53d9b5c8d2553d4 100644 (file)
@@ -7,6 +7,7 @@ include $(GOROOT)/src/Make.$(GOARCH)
 TARG=gofmt
 GOFILES=\
        gofmt.go\
+       rewrite.go\
 
 include $(GOROOT)/src/Make.cmd
 
index bec4c88918680222e95cbd12ecbe0d6118d33620..d7c96dc3ac943d1e8216da67668a73d5df4db041 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bytes";
        "flag";
        "fmt";
+       "go/ast";
        "go/parser";
        "go/printer";
        "go/scanner";
@@ -20,8 +21,9 @@ import (
 
 var (
        // main operation modes
-       list    = flag.Bool("l", false, "list files whose formatting differs from gofmt's");
-       write   = flag.Bool("w", false, "write result to (source) file instead of stdout");
+       list            = flag.Bool("l", false, "list files whose formatting differs from gofmt's");
+       write           = flag.Bool("w", false, "write result to (source) file instead of stdout");
+       rewriteRule     = flag.String("r", "", "rewrite rule (e.g., 'α[β:len(α)] -> α[β:]')");
 
        // debugging support
        comments        = flag.Bool("comments", true, "print comments");
@@ -34,6 +36,8 @@ var (
 
 
 var exitCode = 0
+var rewrite func(*ast.File) *ast.File
+
 
 func report(err os.Error) {
        scanner.PrintError(os.Stderr, err);
@@ -86,6 +90,10 @@ func processFile(filename string) os.Error {
                return err
        }
 
+       if rewrite != nil {
+               file = rewrite(file)
+       }
+
        var res bytes.Buffer;
        _, err = (&printer.Config{printerMode(), *tabwidth, nil}).Fprint(&res, file);
        if err != nil {
@@ -154,6 +162,8 @@ func main() {
                os.Exit(2);
        }
 
+       initRewrite();
+
        if flag.NArg() == 0 {
                if err := processFile("/dev/stdin"); err != nil {
                        report(err)
diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go
new file mode 100644 (file)
index 0000000..9399bcd
--- /dev/null
@@ -0,0 +1,226 @@
+// Copyright 2009 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "fmt";
+       "go/ast";
+       "go/parser";
+       "go/token";
+       "os";
+       "reflect";
+       "strings";
+       "unicode";
+       "utf8";
+)
+
+
+func initRewrite() {
+       if *rewriteRule == "" {
+               return
+       }
+       f := strings.Split(*rewriteRule, "->", 0);
+       if len(f) != 2 {
+               fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n");
+               os.Exit(2);
+       }
+       pattern := parseExpr(f[0], "pattern");
+       replace := parseExpr(f[1], "replacement");
+       rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) };
+}
+
+
+// parseExpr parses s as an expression.
+// It might make sense to expand this to allow statement patterns,
+// but there are problems with preserving formatting and also
+// with what a wildcard for a statement looks like.
+func parseExpr(s string, what string) ast.Expr {
+       stmts, err := parser.ParseStmtList("input", s);
+       if err != nil {
+               fmt.Fprintf(os.Stderr, "parsing %s %s: %s\n", what, s, err);
+               os.Exit(2);
+       }
+       if len(stmts) != 1 {
+               fmt.Fprintf(os.Stderr, "%s must be single expression\n", what);
+               os.Exit(2);
+       }
+       x, ok := stmts[0].(*ast.ExprStmt);
+       if !ok {
+               fmt.Fprintf(os.Stderr, "%s must be single expression\n", what);
+               os.Exit(2);
+       }
+       return x.X;
+}
+
+
+// rewriteFile applys the rewrite rule pattern -> replace to an entire file.
+func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
+       m := make(map[string]reflect.Value);
+       pat := reflect.NewValue(pattern);
+       repl := reflect.NewValue(replace);
+       var f func(val reflect.Value) reflect.Value;    // f is recursive
+       f = func(val reflect.Value) reflect.Value {
+               for k := range m {
+                       m[k] = nil, false
+               }
+               if match(m, pat, val) {
+                       return subst(m, repl)
+               }
+               return apply(f, val);
+       };
+       return apply(f, reflect.NewValue(p)).Interface().(*ast.File);
+}
+
+
+var positionType = reflect.Typeof(token.Position{})
+var zeroPosition = reflect.NewValue(token.Position{})
+var identType = reflect.Typeof((*ast.Ident)(nil))
+
+
+func isWildcard(s string) bool {
+       rune, _ := utf8.DecodeRuneInString(s);
+       return unicode.Is(unicode.Greek, rune) && unicode.IsLower(rune);
+}
+
+
+// apply replaces each AST field x in val with f(x), returning val.
+// To avoid extra conversions, f operates on the reflect.Value form.
+func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
+       if val == nil {
+               return nil
+       }
+       switch v := reflect.Indirect(val).(type) {
+       case *reflect.SliceValue:
+               for i := 0; i < v.Len(); i++ {
+                       e := v.Elem(i);
+                       e.SetValue(f(e));
+               }
+       case *reflect.StructValue:
+               for i := 0; i < v.NumField(); i++ {
+                       e := v.Field(i);
+                       e.SetValue(f(e));
+               }
+       case *reflect.InterfaceValue:
+               e := v.Elem();
+               v.SetValue(f(e));
+       }
+       return val;
+}
+
+
+// match returns true if pattern matches val,
+// recording wildcard submatches in m.
+// If m == nil, match checks whether pattern == val.
+func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
+       // Wildcard matches any expression.  If it appears multiple
+       // times in the pattern, it must match the same expression
+       // each time.
+       if m != nil && pattern.Type() == identType {
+               name := pattern.Interface().(*ast.Ident).Value;
+               if isWildcard(name) {
+                       if old, ok := m[name]; ok {
+                               return match(nil, old, val)
+                       }
+                       m[name] = val;
+                       return true;
+               }
+       }
+
+       // Otherwise, the expressions must match recursively.
+       if pattern == nil || val == nil {
+               return pattern == nil && val == nil
+       }
+       if pattern.Type() != val.Type() {
+               return false
+       }
+
+       // Token positions need not match.
+       if pattern.Type() == positionType {
+               return true
+       }
+
+       p := reflect.Indirect(pattern);
+       v := reflect.Indirect(val);
+
+       switch p := p.(type) {
+       case *reflect.SliceValue:
+               v := v.(*reflect.SliceValue);
+               for i := 0; i < p.Len(); i++ {
+                       if !match(m, p.Elem(i), v.Elem(i)) {
+                               return false
+                       }
+               }
+               return true;
+
+       case *reflect.StructValue:
+               v := v.(*reflect.StructValue);
+               for i := 0; i < p.NumField(); i++ {
+                       if !match(m, p.Field(i), v.Field(i)) {
+                               return false
+                       }
+               }
+               return true;
+
+       case *reflect.InterfaceValue:
+               v := v.(*reflect.InterfaceValue);
+               return match(m, p.Elem(), v.Elem());
+       }
+
+       // Handle token integers, etc.
+       return p.Interface() == v.Interface();
+}
+
+
+// subst returns a copy of pattern with values from m substituted in place of wildcards.
+// if m == nil, subst returns a copy of pattern.
+// Either way, the returned value has no valid line number information.
+func subst(m map[string]reflect.Value, pattern reflect.Value) reflect.Value {
+       if pattern == nil {
+               return nil
+       }
+
+       // Wildcard gets replaced with map value.
+       if m != nil && pattern.Type() == identType {
+               name := pattern.Interface().(*ast.Ident).Value;
+               if isWildcard(name) {
+                       if old, ok := m[name]; ok {
+                               return subst(nil, old)
+                       }
+               }
+       }
+
+       if pattern.Type() == positionType {
+               return zeroPosition
+       }
+
+       // Otherwise copy.
+       switch p := pattern.(type) {
+       case *reflect.SliceValue:
+               v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len());
+               for i := 0; i < p.Len(); i++ {
+                       v.Elem(i).SetValue(subst(m, p.Elem(i)))
+               }
+               return v;
+
+       case *reflect.StructValue:
+               v := reflect.MakeZero(p.Type()).(*reflect.StructValue);
+               for i := 0; i < p.NumField(); i++ {
+                       v.Field(i).SetValue(subst(m, p.Field(i)))
+               }
+               return v;
+
+       case *reflect.PtrValue:
+               v := reflect.MakeZero(p.Type()).(*reflect.PtrValue);
+               v.PointTo(subst(m, p.Elem()));
+               return v;
+
+       case *reflect.InterfaceValue:
+               v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue);
+               v.SetValue(subst(m, p.Elem()));
+               return v;
+       }
+
+       return pattern;
+}
index 6304830bd35811946931098074ed7d9ff9dff9fe..1c7460313a664a3fe8c66437016c05c7daf5e8eb 100644 (file)
@@ -52,6 +52,17 @@ func (p *printer) linebreak(line, min, max int, ws whiteSpace, newSection bool)
        case n > max:
                n = max
        }
+
+       // TODO(gri): try to avoid direct manipulation of p.pos
+       // demo of why this is necessary: run gofmt -r 'i < i -> i < j' x.go on this x.go:
+       //      package main
+       //      func main() {
+       //              i < i;
+       //              j < 10;
+       //      }
+       //
+       p.pos.Line += n;
+
        if n > 0 {
                p.print(ws);
                if newSection {
@@ -199,7 +210,7 @@ func (p *printer) exprList(prev token.Position, list []ast.Expr, depth int, mode
                        if mode&commaSep != 0 {
                                p.print(token.COMMA)
                        }
-                       if prev < line {
+                       if prev < line && prev > 0 && line > 0 {
                                if p.linebreak(line, 1, 2, ws, true) {
                                        ws = ignore;
                                        *multiLine = true;
@@ -564,8 +575,7 @@ func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int, multiL
        xline := p.pos.Line;    // before the operator (it may be on the next line!)
        yline := x.Y.Pos().Line;
        p.print(x.OpPos, x.Op);
-       if xline != yline {
-               //println(x.OpPos.String());
+       if xline != yline && xline > 0 && yline > 0 {
                // at least one line break, but respect an extra empty line
                // in the source
                if p.linebreak(yline, 1, 2, ws, true) {