]> Cypherpunks repositories - gostls13.git/commitdiff
testing/regexp: use recover.
authorRob Pike <r@golang.org>
Thu, 1 Apr 2010 00:57:50 +0000 (17:57 -0700)
committerRob Pike <r@golang.org>
Thu, 1 Apr 2010 00:57:50 +0000 (17:57 -0700)
R=rsc
CC=golang-dev
https://golang.org/cl/816042

src/pkg/regexp/regexp.go
src/pkg/testing/regexp.go

index 9f0ee191a72be5bc401d11fddcbc57d622d0ee21..f8d03d743f0604d27b7132638effead318f5d359 100644 (file)
@@ -1,4 +1,3 @@
-// Copyright 2009 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -264,8 +263,7 @@ type parser struct {
        ch    int
 }
 
-func (p *parser) error(err os.Error) {
-       p.re = nil
+func (p *parser) error(err Error) {
        panic(err)
 }
 
index cd37699ce3e7aafb82eba409507e98971952335a..78d801d51be5134e32302fa2a498c46934c05b20 100644 (file)
@@ -34,13 +34,13 @@ var debug = false
 // Error codes returned by failures to parse an expression.
 var (
        ErrInternal            = "internal error"
-       ErrUnmatchedLpar       = "unmatched ''"
-       ErrUnmatchedRpar       = "unmatched ''"
+       ErrUnmatchedLpar       = "unmatched '('"
+       ErrUnmatchedRpar       = "unmatched ')'"
        ErrUnmatchedLbkt       = "unmatched '['"
        ErrUnmatchedRbkt       = "unmatched ']'"
        ErrBadRange            = "bad range in character class"
        ErrExtraneousBackslash = "extraneous backslash"
-       ErrBadClosure          = "repeated closure **, ++, etc."
+       ErrBadClosure          = "repeated closure (**, ++, etc.)"
        ErrBareClosure         = "closure applies to nothing"
        ErrBadBackslash        = "illegal backslash escape"
 )
@@ -267,12 +267,15 @@ func (re *Regexp) add(i instr) instr {
 
 type parser struct {
        re    *Regexp
-       error string
        nlpar int // number of unclosed lpars
        pos   int
        ch    int
 }
 
+func (p *parser) error(err string) {
+       panic(err)
+}
+
 const endOfFile = -1
 
 func (p *parser) c() int { return p.ch }
@@ -326,8 +329,7 @@ func (p *parser) charClass() instr {
                switch c := p.c(); c {
                case ']', endOfFile:
                        if left >= 0 {
-                               p.error = ErrBadRange
-                               return nil
+                               p.error(ErrBadRange)
                        }
                        // Is it [^\n]?
                        if cc.negate && len(cc.ranges) == 2 &&
@@ -339,21 +341,18 @@ func (p *parser) charClass() instr {
                        p.re.add(cc)
                        return cc
                case '-': // do this before backslash processing
-                       p.error = ErrBadRange
-                       return nil
+                       p.error(ErrBadRange)
                case '\\':
                        c = p.nextc()
                        switch {
                        case c == endOfFile:
-                               p.error = ErrExtraneousBackslash
-                               return nil
+                               p.error(ErrExtraneousBackslash)
                        case c == 'n':
                                c = '\n'
                        case specialcclass(c):
                        // c is as delivered
                        default:
-                               p.error = ErrBadBackslash
-                               return nil
+                               p.error(ErrBadBackslash)
                        }
                        fallthrough
                default:
@@ -370,8 +369,7 @@ func (p *parser) charClass() instr {
                                cc.addRange(left, c)
                                left = -1
                        default:
-                               p.error = ErrBadRange
-                               return nil
+                               p.error(ErrBadRange)
                        }
                }
        }
@@ -379,28 +377,19 @@ func (p *parser) charClass() instr {
 }
 
 func (p *parser) term() (start, end instr) {
-       // term() is the leaf of the recursion, so it's sufficient to pick off the
-       // error state here for early exit.
-       // The other functions (closure(), concatenation() etc.) assume
-       // it's safe to recur to here.
-       if p.error != "" {
-               return
-       }
        switch c := p.c(); c {
        case '|', endOfFile:
                return nil, nil
        case '*', '+':
-               p.error = ErrBareClosure
+               p.error(ErrBareClosure)
                return
        case ')':
                if p.nlpar == 0 {
-                       p.error = ErrUnmatchedRpar
-                       return
+                       p.error(ErrUnmatchedRpar)
                }
                return nil, nil
        case ']':
-               p.error = ErrUnmatchedRbkt
-               return
+               p.error(ErrUnmatchedRbkt)
        case '^':
                p.nextc()
                start = p.re.add(new(_Bot))
@@ -416,12 +405,8 @@ func (p *parser) term() (start, end instr) {
        case '[':
                p.nextc()
                start = p.charClass()
-               if p.error != "" {
-                       return
-               }
                if p.c() != ']' {
-                       p.error = ErrUnmatchedLbkt
-                       return
+                       p.error(ErrUnmatchedLbkt)
                }
                p.nextc()
                return start, start
@@ -432,8 +417,7 @@ func (p *parser) term() (start, end instr) {
                nbra := p.re.nbra
                start, end = p.regexp()
                if p.c() != ')' {
-                       p.error = ErrUnmatchedLpar
-                       return
+                       p.error(ErrUnmatchedLpar)
                }
                p.nlpar--
                p.nextc()
@@ -445,8 +429,7 @@ func (p *parser) term() (start, end instr) {
                ebra.n = nbra
                if start == nil {
                        if end == nil {
-                               p.error = ErrInternal
-                               return
+                               p.error(ErrInternal)
                        }
                        start = ebra
                } else {
@@ -458,15 +441,14 @@ func (p *parser) term() (start, end instr) {
                c = p.nextc()
                switch {
                case c == endOfFile:
-                       p.error = ErrExtraneousBackslash
+                       p.error(ErrExtraneousBackslash)
                        return
                case c == 'n':
                        c = '\n'
                case special(c):
                // c is as delivered
                default:
-                       p.error = ErrBadBackslash
-                       return
+                       p.error(ErrBadBackslash)
                }
                fallthrough
        default:
@@ -480,7 +462,7 @@ func (p *parser) term() (start, end instr) {
 
 func (p *parser) closure() (start, end instr) {
        start, end = p.term()
-       if start == nil || p.error != "" {
+       if start == nil {
                return
        }
        switch p.c() {
@@ -515,7 +497,7 @@ func (p *parser) closure() (start, end instr) {
        }
        switch p.nextc() {
        case '*', '+', '?':
-               p.error = ErrBadClosure
+               p.error(ErrBadClosure)
        }
        return
 }
@@ -523,9 +505,6 @@ func (p *parser) closure() (start, end instr) {
 func (p *parser) concatenation() (start, end instr) {
        for {
                nstart, nend := p.closure()
-               if p.error != "" {
-                       return
-               }
                switch {
                case nstart == nil: // end of this concatenation
                        if start == nil { // this is the empty string
@@ -545,9 +524,6 @@ func (p *parser) concatenation() (start, end instr) {
 
 func (p *parser) regexp() (start, end instr) {
        start, end = p.concatenation()
-       if p.error != "" {
-               return
-       }
        for {
                switch p.c() {
                default:
@@ -555,9 +531,6 @@ func (p *parser) regexp() (start, end instr) {
                case '|':
                        p.nextc()
                        nstart, nend := p.concatenation()
-                       if p.error != "" {
-                               return
-                       }
                        alt := new(_Alt)
                        p.re.add(alt)
                        alt.left = start
@@ -593,31 +566,31 @@ func (re *Regexp) eliminateNops() {
        }
 }
 
-func (re *Regexp) doParse() string {
+func (re *Regexp) doParse() {
        p := newParser(re)
        start := new(_Start)
        re.add(start)
        s, e := p.regexp()
-       if p.error != "" {
-               return p.error
-       }
        start.setNext(s)
        re.start = start
        e.setNext(re.add(new(_End)))
        re.eliminateNops()
-       return p.error
 }
 
 // CompileRegexp parses a regular expression and returns, if successful, a Regexp
 // object that can be used to match against text.
 func CompileRegexp(str string) (regexp *Regexp, error string) {
        regexp = new(Regexp)
+       // doParse will panic if there is a parse error.
+       defer func() {
+               if e := recover(); e != nil {
+                       regexp = nil
+                       error = e.(string) // Will re-panic if error was not a string, e.g. nil-pointer exception
+               }
+       }()
        regexp.expr = str
        regexp.inst = make([]instr, 0, 20)
-       error = regexp.doParse()
-       if error != "" {
-               regexp = nil
-       }
+       regexp.doParse()
        return
 }