]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/yacc: always import fmt, safely
authorRob Pike <r@golang.org>
Thu, 6 Sep 2012 21:58:37 +0000 (14:58 -0700)
committerRob Pike <r@golang.org>
Thu, 6 Sep 2012 21:58:37 +0000 (14:58 -0700)
The parser depends on it but the client might not import it, so make sure it's there.
Fixes #4038.

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6497094

src/cmd/yacc/yacc.go

index e94228152190a716d3f35487294eff599b4d5192..cca5570fb8dab122bbb0df714555c5569ad7166f 100644 (file)
@@ -51,6 +51,7 @@ import (
        "fmt"
        "os"
        "strings"
+       "unicode"
 )
 
 // the following are adjustable
@@ -153,6 +154,8 @@ var ftable *bufio.Writer    // y.go file
 var fcode = &bytes.Buffer{} // saved code
 var foutput *bufio.Writer   // y.output file
 
+var fmtImported bool // output file has recorded an import of "fmt"
+
 var oflag string  // -o [y.go]         - y.go file
 var vflag string  // -v [y.output]     - y.output file
 var lflag bool    // -l                        - disable line directives
@@ -1073,6 +1076,7 @@ out:
 
 //
 // saves code between %{ and %}
+// adds an import for __fmt__ the first time
 //
 func cpycode() {
        lno := lineno
@@ -1085,15 +1089,18 @@ func cpycode() {
        if !lflag {
                fmt.Fprintf(ftable, "\n//line %v:%v\n", infile, lineno)
        }
+       // accumulate until %}
+       code := make([]rune, 0, 1024)
        for c != EOF {
                if c == '%' {
                        c = getrune(finput)
                        if c == '}' {
+                               emitcode(code, lno+1)
                                return
                        }
-                       ftable.WriteRune('%')
+                       code = append(code, '%')
                }
-               ftable.WriteRune(c)
+               code = append(code, c)
                if c == '\n' {
                        lineno++
                }
@@ -1103,6 +1110,107 @@ func cpycode() {
        errorf("eof before %%}")
 }
 
+//
+// emits code saved up from between %{ and %}
+// called by cpycode
+// adds an import for __yyfmt__ after the package clause
+//
+func emitcode(code []rune, lineno int) {
+       for i, line := range lines(code) {
+               writecode(line)
+               if !fmtImported && isPackageClause(line) {
+                       fmt.Fprintln(ftable, `import __yyfmt__ "fmt"`)
+                       fmt.Fprintf(ftable, "//line %v:%v\n\t\t", infile, lineno+i)
+                       fmtImported = true
+               }
+       }
+}
+
+//
+// does this line look like a package clause?  not perfect: might be confused by early comments.
+//
+func isPackageClause(line []rune) bool {
+       line = skipspace(line)
+
+       // must be big enough.
+       if len(line) < len("package X\n") {
+               return false
+       }
+
+       // must start with "package"
+       for i, r := range []rune("package") {
+               if line[i] != r {
+                       return false
+               }
+       }
+       line = skipspace(line[len("package"):])
+
+       // must have another identifier.
+       if len(line) == 0 || (!unicode.IsLetter(line[0]) && line[0] != '_') {
+               return false
+       }
+       for len(line) > 0 {
+               if !unicode.IsLetter(line[0]) && !unicode.IsDigit(line[0]) && line[0] != '_' {
+                       break
+               }
+               line = line[1:]
+       }
+       line = skipspace(line)
+
+       // eol, newline, or comment must follow
+       if len(line) == 0 {
+               return true
+       }
+       if line[0] == '\r' || line[0] == '\n' {
+               return true
+       }
+       if len(line) >= 2 {
+               return line[0] == '/' && (line[1] == '/' || line[1] == '*')
+       }
+       return false
+}
+
+//
+// skip initial spaces
+//
+func skipspace(line []rune) []rune {
+       for len(line) > 0 {
+               if line[0] != ' ' && line[0] != '\t' {
+                       break
+               }
+               line = line[1:]
+       }
+       return line
+}
+
+//
+// break code into lines
+//
+func lines(code []rune) [][]rune {
+       l := make([][]rune, 0, 100)
+       for len(code) > 0 {
+               // one line per loop
+               var i int
+               for i = range code {
+                       if code[i] == '\n' {
+                               break
+                       }
+               }
+               l = append(l, code[:i+1])
+               code = code[i+1:]
+       }
+       return l
+}
+
+//
+// writes code to ftable
+//
+func writecode(code []rune) {
+       for _, r := range code {
+               ftable.WriteRune(r)
+       }
+}
+
 //
 // skip over comments
 // skipcom is called after reading a '/'
@@ -3115,7 +3223,7 @@ func $$Tokname(c int) string {
                        return $$Toknames[c-1]
                }
        }
-       return fmt.Sprintf("tok-%v", c)
+       return __yyfmt__.Sprintf("tok-%v", c)
 }
 
 func $$Statname(s int) string {
@@ -3124,7 +3232,7 @@ func $$Statname(s int) string {
                        return $$Statenames[s]
                }
        }
-       return fmt.Sprintf("state-%v", s)
+       return __yyfmt__.Sprintf("state-%v", s)
 }
 
 func $$lex1(lex $$Lexer, lval *$$SymType) int {
@@ -3157,7 +3265,7 @@ out:
                c = $$Tok2[1] /* unknown char */
        }
        if $$Debug >= 3 {
-               fmt.Printf("lex %U %s\n", uint(char), $$Tokname(c))
+               __yyfmt__.Printf("lex %U %s\n", uint(char), $$Tokname(c))
        }
        return c
 }
@@ -3184,7 +3292,7 @@ ret1:
 $$stack:
        /* put a state and value onto the stack */
        if $$Debug >= 4 {
-               fmt.Printf("char %v in %v\n", $$Tokname($$char), $$Statname($$state))
+               __yyfmt__.Printf("char %v in %v\n", $$Tokname($$char), $$Statname($$state))
        }
 
        $$p++
@@ -3253,8 +3361,8 @@ $$default:
                        $$lex.Error("syntax error")
                        Nerrs++
                        if $$Debug >= 1 {
-                               fmt.Printf("%s", $$Statname($$state))
-                               fmt.Printf("saw %s\n", $$Tokname($$char))
+                               __yyfmt__.Printf("%s", $$Statname($$state))
+                               __yyfmt__.Printf("saw %s\n", $$Tokname($$char))
                        }
                        fallthrough
 
@@ -3273,7 +3381,7 @@ $$default:
 
                                /* the current p has no shift on "error", pop stack */
                                if $$Debug >= 2 {
-                                       fmt.Printf("error recovery pops state %d\n", $$S[$$p].yys)
+                                       __yyfmt__.Printf("error recovery pops state %d\n", $$S[$$p].yys)
                                }
                                $$p--
                        }
@@ -3282,7 +3390,7 @@ $$default:
 
                case 3: /* no shift yet; clobber input char */
                        if $$Debug >= 2 {
-                               fmt.Printf("error recovery discards %s\n", $$Tokname($$char))
+                               __yyfmt__.Printf("error recovery discards %s\n", $$Tokname($$char))
                        }
                        if $$char == $$EofCode {
                                goto ret1
@@ -3294,7 +3402,7 @@ $$default:
 
        /* reduction by production $$n */
        if $$Debug >= 2 {
-               fmt.Printf("reduce %v in:\n\t%v\n", $$n, $$Statname($$state))
+               __yyfmt__.Printf("reduce %v in:\n\t%v\n", $$n, $$Statname($$state))
        }
 
        $$nt := $$n