]> Cypherpunks repositories - gostls13.git/commitdiff
gofix: fixes for os/signal changes
authorRobert Hencke <robert.hencke@gmail.com>
Wed, 29 Jun 2011 06:44:47 +0000 (16:44 +1000)
committerAndrew Gerrand <adg@golang.org>
Wed, 29 Jun 2011 06:44:47 +0000 (16:44 +1000)
Fixes #1971.

R=adg, rsc
CC=golang-dev
https://golang.org/cl/4630056

src/cmd/gofix/Makefile
src/cmd/gofix/fix.go
src/cmd/gofix/signal.go [new file with mode: 0644]
src/cmd/gofix/signal_test.go [new file with mode: 0644]

index 02d7463078bb50c4d20c86200c3121447af861d5..e74b639df4fa423ae79788077b3ebf48926bb140 100644 (file)
@@ -18,6 +18,7 @@ GOFILES=\
        osopen.go\
        procattr.go\
        reflect.go\
+       signal.go\
        sortslice.go\
        stringssplit.go\
        typecheck.go\
index 0852ce21edfe1132a06c8a660ca4c69287852873..c1c5a746ccf04bd1baab19221e8c223af2731b9c 100644 (file)
@@ -10,6 +10,7 @@ import (
        "go/token"
        "os"
        "strconv"
+       "strings"
 )
 
 type fix struct {
@@ -258,13 +259,28 @@ func walkBeforeAfter(x interface{}, before, after func(interface{})) {
 
 // imports returns true if f imports path.
 func imports(f *ast.File, path string) bool {
+       return importSpec(f, path) != nil
+}
+
+// importSpec returns the import spec if f imports path,
+// or nil otherwise.
+func importSpec(f *ast.File, path string) *ast.ImportSpec {
        for _, s := range f.Imports {
-               t, err := strconv.Unquote(s.Path.Value)
-               if err == nil && t == path {
-                       return true
+               if importPath(s) == path {
+                       return s
                }
        }
-       return false
+       return nil
+}
+
+// importPath returns the unquoted import path of s,
+// or "" if the path is not properly quoted.
+func importPath(s *ast.ImportSpec) string {
+       t, err := strconv.Unquote(s.Path.Value)
+       if err == nil {
+               return t
+       }
+       return ""
 }
 
 // isPkgDot returns true if t is the expression "pkg.name"
@@ -420,3 +436,138 @@ func newPkgDot(pos token.Pos, pkg, name string) ast.Expr {
                },
        }
 }
+
+// addImport adds the import path to the file f, if absent.
+func addImport(f *ast.File, path string) {
+       if imports(f, path) {
+               return
+       }
+
+       newImport := &ast.ImportSpec{
+               Path: &ast.BasicLit{
+                       Kind:  token.STRING,
+                       Value: strconv.Quote(path),
+               },
+       }
+
+       var impdecl *ast.GenDecl
+
+       // Find an import decl to add to.
+       for _, decl := range f.Decls {
+               gen, ok := decl.(*ast.GenDecl)
+
+               if ok && gen.Tok == token.IMPORT {
+                       impdecl = gen
+                       break
+               }
+       }
+
+       // No import decl found.  Add one.
+       if impdecl == nil {
+               impdecl = &ast.GenDecl{
+                       Tok: token.IMPORT,
+               }
+               f.Decls = append(f.Decls, nil)
+               copy(f.Decls[1:], f.Decls)
+               f.Decls[0] = impdecl
+       }
+
+       // Ensure the import decl has parentheses, if needed.
+       if len(impdecl.Specs) > 0 && !impdecl.Lparen.IsValid() {
+               impdecl.Lparen = impdecl.Pos()
+       }
+
+       // Assume the import paths are alphabetically ordered.
+       // If they are not, the result is ugly, but legal.
+       insertAt := len(impdecl.Specs) // default to end of specs
+       for i, spec := range impdecl.Specs {
+               impspec := spec.(*ast.ImportSpec)
+               if importPath(impspec) > path {
+                       insertAt = i
+                       break
+               }
+       }
+
+       impdecl.Specs = append(impdecl.Specs, nil)
+       copy(impdecl.Specs[insertAt+1:], impdecl.Specs[insertAt:])
+       impdecl.Specs[insertAt] = newImport
+
+       f.Imports = append(f.Imports, newImport)
+}
+
+// deleteImport deletes the import path from the file f, if present.
+func deleteImport(f *ast.File, path string) {
+       oldImport := importSpec(f, path)
+
+       // Find the import node that imports path, if any.
+       for i, decl := range f.Decls {
+               gen, ok := decl.(*ast.GenDecl)
+               if !ok || gen.Tok != token.IMPORT {
+                       continue
+               }
+               for j, spec := range gen.Specs {
+                       impspec := spec.(*ast.ImportSpec)
+
+                       if oldImport != impspec {
+                               continue
+                       }
+
+                       // We found an import spec that imports path.
+                       // Delete it.
+                       copy(gen.Specs[j:], gen.Specs[j+1:])
+                       gen.Specs = gen.Specs[:len(gen.Specs)-1]
+
+                       // If this was the last import spec in this decl,
+                       // delete the decl, too.
+                       if len(gen.Specs) == 0 {
+                               copy(f.Decls[i:], f.Decls[i+1:])
+                               f.Decls = f.Decls[:len(f.Decls)-1]
+                       } else if len(gen.Specs) == 1 {
+                               gen.Lparen = token.NoPos // drop parens
+                       }
+
+                       break
+               }
+       }
+
+       // Delete it from f.Imports.
+       for i, imp := range f.Imports {
+               if imp == oldImport {
+                       copy(f.Imports[i:], f.Imports[i+1:])
+                       f.Imports = f.Imports[:len(f.Imports)-1]
+                       break
+               }
+       }
+}
+
+func usesImport(f *ast.File, path string) (used bool) {
+       spec := importSpec(f, path)
+       if spec == nil {
+               return
+       }
+
+       name := spec.Name.String()
+       switch name {
+       case "<nil>":
+               // If the package name is not explicitly specified,
+               // make an educated guess. This is not guaranteed to be correct.
+               lastSlash := strings.LastIndex(path, "/")
+               if lastSlash == -1 {
+                       name = path
+               } else {
+                       name = path[lastSlash+1:]
+               }
+       case "_", ".":
+               // Not sure if this import is used - err on the side of caution.
+               return true
+       }
+
+       walk(f, func(n interface{}) {
+               sel, ok := n.(*ast.SelectorExpr)
+               if ok && isTopName(sel.X, name) {
+                       used = true
+               }
+       })
+
+       return
+}
diff --git a/src/cmd/gofix/signal.go b/src/cmd/gofix/signal.go
new file mode 100644 (file)
index 0000000..53c3388
--- /dev/null
@@ -0,0 +1,49 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "go/ast"
+       "strings"
+)
+
+func init() {
+       register(fix{
+               "signal",
+               signal,
+               `Adapt code to types moved from os/signal to signal.
+
+http://codereview.appspot.com/4437091
+`,
+       })
+}
+
+func signal(f *ast.File) (fixed bool) {
+       if !imports(f, "os/signal") {
+               return
+       }
+
+       walk(f, func(n interface{}) {
+               s, ok := n.(*ast.SelectorExpr)
+
+               if !ok || !isTopName(s.X, "signal") {
+                       return
+               }
+
+               sel := s.Sel.String()
+               if sel == "Signal" || sel == "UnixSignal" || strings.HasPrefix(sel, "SIG") {
+                       s.X = &ast.Ident{Name: "os"}
+                       fixed = true
+               }
+       })
+
+       if fixed {
+               addImport(f, "os")
+               if !usesImport(f, "os/signal") {
+                       deleteImport(f, "os/signal")
+               }
+       }
+       return
+}
diff --git a/src/cmd/gofix/signal_test.go b/src/cmd/gofix/signal_test.go
new file mode 100644 (file)
index 0000000..2500e9c
--- /dev/null
@@ -0,0 +1,96 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+       addTestCases(signalTests)
+}
+
+var signalTests = []testCase{
+       {
+               Name: "signal.0",
+               In: `package main
+
+import (
+       _ "a"
+       "os/signal"
+       _ "z"
+)
+
+type T1 signal.UnixSignal
+type T2 signal.Signal
+
+func f() {
+       _ = signal.SIGHUP
+       _ = signal.Incoming
+}
+`,
+               Out: `package main
+
+import (
+       _ "a"
+       "os"
+       "os/signal"
+       _ "z"
+)
+
+type T1 os.UnixSignal
+type T2 os.Signal
+
+func f() {
+       _ = os.SIGHUP
+       _ = signal.Incoming
+}
+`,
+       },
+       {
+               Name: "signal.1",
+               In: `package main
+
+import (
+       "os"
+       "os/signal"
+)
+
+func f() {
+       var _ os.Error
+       _ = signal.SIGHUP
+}
+`,
+               Out: `package main
+
+import "os"
+
+
+func f() {
+       var _ os.Error
+       _ = os.SIGHUP
+}
+`,
+       },
+       {
+               Name: "signal.2",
+               In: `package main
+
+import "os"
+import "os/signal"
+
+func f() {
+       var _ os.Error
+       _ = signal.SIGHUP
+}
+`,
+               Out: `package main
+
+import "os"
+
+
+func f() {
+       var _ os.Error
+       _ = os.SIGHUP
+}
+`,
+       },
+}