]> Cypherpunks repositories - gostls13.git/commitdiff
strings: implement a faster byte->byte Replacer
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 3 Oct 2011 20:12:01 +0000 (13:12 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 3 Oct 2011 20:12:01 +0000 (13:12 -0700)
When all old & new string values are single bytes,
byteReplacer is now used, instead of the generic
algorithm.

BenchmarkGenericMatch       10000  102519 ns/op
BenchmarkByteByteMatch    1000000    2178 ns/op

fast path, when nothing matches:
BenchmarkByteByteNoMatch  1000000    1109 ns/op

comparisons to multiple Replace calls:
BenchmarkByteByteReplaces  100000   16164 ns/op

comparison to strings.Map:
BenchmarkByteByteMap       500000    5454 ns/op

R=rsc
CC=golang-dev
https://golang.org/cl/5175050

src/pkg/http/cookie.go
src/pkg/strings/export_test.go [new file with mode: 0644]
src/pkg/strings/replace.go
src/pkg/strings/replace_test.go

index fe70431bbbcc046c04ab014e81d61890e71f55c0..69350143248540db7f7622754aa65dbafdb922be 100644 (file)
@@ -207,17 +207,16 @@ func readCookies(h Header, filter string) []*Cookie {
        return cookies
 }
 
+var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")
+
 func sanitizeName(n string) string {
-       n = strings.Replace(n, "\n", "-", -1)
-       n = strings.Replace(n, "\r", "-", -1)
-       return n
+       return cookieNameSanitizer.Replace(n)
 }
 
+var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ")
+
 func sanitizeValue(v string) string {
-       v = strings.Replace(v, "\n", " ", -1)
-       v = strings.Replace(v, "\r", " ", -1)
-       v = strings.Replace(v, ";", " ", -1)
-       return v
+       return cookieValueSanitizer.Replace(v)
 }
 
 func unquoteCookieValue(v string) string {
diff --git a/src/pkg/strings/export_test.go b/src/pkg/strings/export_test.go
new file mode 100644 (file)
index 0000000..dcfec51
--- /dev/null
@@ -0,0 +1,9 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package strings
+
+func (r *Replacer) Replacer() interface{} {
+       return r.r
+}
index cf2c023be062da917cf6cbf49b6e5f96478bfca2..8eab12d3c832ca5704f96b2127d4498320ac6551 100644 (file)
@@ -9,20 +9,24 @@ import (
        "os"
 )
 
-// Can't import ioutil for ioutil.Discard, due to ioutil/tempfile.go -> strconv -> strings
-var discard io.Writer = devNull(0)
-
-type devNull int
+// A Replacer replaces a list of strings with replacements.
+type Replacer struct {
+       r replacer
+}
 
-func (devNull) Write(p []byte) (int, os.Error) {
-       return len(p), nil
+// replacer is the interface that a replacement algorithm needs to implement.
+type replacer interface {
+       Replace(s string) string
+       WriteString(w io.Writer, s string) (n int, err os.Error)
 }
 
-type pair struct{ old, new string }
+// byteBitmap represents bytes which are sought for replacement.
+// byteBitmap is 256 bits wide, with a bit set for each old byte to be
+// replaced.
+type byteBitmap [256 / 32]uint32
 
-// A Replacer replaces a list of strings with replacements.
-type Replacer struct {
-       p []pair
+func (m *byteBitmap) set(b byte) {
+       m[b>>5] |= uint32(1 << (b & 31))
 }
 
 // NewReplacer returns a new Replacer from a list of old, new string pairs.
@@ -31,14 +35,51 @@ func NewReplacer(oldnew ...string) *Replacer {
        if len(oldnew)%2 == 1 {
                panic("strings.NewReplacer: odd argument count")
        }
-       r := new(Replacer)
-       for len(oldnew) >= 2 {
-               r.p = append(r.p, pair{oldnew[0], oldnew[1]})
+
+       var bb byteReplacer
+       var gen genericReplacer
+
+       allOldBytes, allNewBytes := true, true
+       for len(oldnew) > 0 {
+               old, new := oldnew[0], oldnew[1]
                oldnew = oldnew[2:]
+               if len(old) != 1 {
+                       allOldBytes = false
+               }
+               if len(new) != 1 {
+                       allNewBytes = false
+               }
+               gen.p = append(gen.p, pair{old, new})
+               if allOldBytes && allNewBytes {
+                       bb.old.set(old[0])
+                       bb.new[old[0]] = new[0]
+               }
        }
-       return r
+
+       if allOldBytes && allNewBytes {
+               return &Replacer{r: &bb}
+       }
+       return &Replacer{r: &gen}
+}
+
+// Replace returns a copy of s with all replacements performed.
+func (r *Replacer) Replace(s string) string {
+       return r.r.Replace(s)
+}
+
+// WriteString writes s to w with all replacements performed.
+func (r *Replacer) WriteString(w io.Writer, s string) (n int, err os.Error) {
+       return r.r.WriteString(w, s)
+}
+
+// genericReplacer is the fully generic (and least optimized) algorithm.
+// It's used as a fallback when nothing faster can be used.
+type genericReplacer struct {
+       p []pair
 }
 
+type pair struct{ old, new string }
+
 type appendSliceWriter struct {
        b []byte
 }
@@ -48,8 +89,7 @@ func (w *appendSliceWriter) Write(p []byte) (int, os.Error) {
        return len(p), nil
 }
 
-// Replace returns a copy of s with all replacements performed.
-func (r *Replacer) Replace(s string) string {
+func (r *genericReplacer) Replace(s string) string {
        // TODO(bradfitz): optimized version
        n, _ := r.WriteString(discard, s)
        w := appendSliceWriter{make([]byte, 0, n)}
@@ -57,19 +97,28 @@ func (r *Replacer) Replace(s string) string {
        return string(w.b)
 }
 
-// WriteString writes s to w with all replacements performed.
-func (r *Replacer) WriteString(w io.Writer, s string) (n int, err os.Error) {
+func (r *genericReplacer) WriteString(w io.Writer, s string) (n int, err os.Error) {
+       lastEmpty := false // the last replacement was of the empty string
 Input:
        // TODO(bradfitz): optimized version
        for i := 0; i < len(s); {
                for _, p := range r.p {
+                       if p.old == "" && lastEmpty {
+                               // Don't let old match twice in a row.
+                               // (it doesn't advance the input and
+                               // would otherwise loop forever)
+                               continue
+                       }
                        if HasPrefix(s[i:], p.old) {
-                               wn, err := w.Write([]byte(p.new))
-                               n += wn
-                               if err != nil {
-                                       return n, err
+                               if p.new != "" {
+                                       wn, err := w.Write([]byte(p.new))
+                                       n += wn
+                                       if err != nil {
+                                               return n, err
+                                       }
                                }
                                i += len(p.old)
+                               lastEmpty = p.old == ""
                                continue Input
                        }
                }
@@ -80,5 +129,81 @@ Input:
                }
                i++
        }
+
+       // Final empty match at end.
+       for _, p := range r.p {
+               if p.old == "" {
+                       if p.new != "" {
+                               wn, err := w.Write([]byte(p.new))
+                               n += wn
+                               if err != nil {
+                                       return n, err
+                               }
+                       }
+                       break
+               }
+       }
+
+       return n, nil
+}
+
+// byteReplacer is the implementation that's used when all the "old"
+// and "new" values are single ASCII bytes.
+type byteReplacer struct {
+       // old has a bit set for each old byte that should be replaced.
+       old byteBitmap
+
+       // replacement byte, indexed by old byte. only valid if
+       // corresponding old bit is set.
+       new [256]byte
+}
+
+func (r *byteReplacer) Replace(s string) string {
+       var buf []byte // lazily allocated
+       for i := 0; i < len(s); i++ {
+               b := s[i]
+               if r.old[b>>5]&uint32(1<<(b&31)) != 0 {
+                       if buf == nil {
+                               buf = []byte(s)
+                       }
+                       buf[i] = r.new[b]
+               }
+       }
+       if buf == nil {
+               return s
+       }
+       return string(buf)
+}
+
+func (r *byteReplacer) WriteString(w io.Writer, s string) (n int, err os.Error) {
+       bufsize := 32 << 10
+       if len(s) < bufsize {
+               bufsize = len(s)
+       }
+       buf := make([]byte, bufsize)
+
+       for len(s) > 0 {
+               ncopy := copy(buf, s[:])
+               s = s[ncopy:]
+               for i, b := range buf[:ncopy] {
+                       if r.old[b>>5]&uint32(1<<(b&31)) != 0 {
+                               buf[i] = r.new[b]
+                       }
+               }
+               wn, err := w.Write(buf[:ncopy])
+               n += wn
+               if err != nil {
+                       return n, err
+               }
+       }
        return n, nil
 }
+
+// strings is too low-level to import io/ioutil
+var discard io.Writer = devNull(0)
+
+type devNull int
+
+func (devNull) Write(p []byte) (int, os.Error) {
+       return len(p), nil
+}
index a7677a132b0d63db99b9e67d2060f6c01c65ab9a..20e734f6cda7b1897ee1b00f1285fb38563a5300 100644 (file)
@@ -5,12 +5,17 @@
 package strings_test
 
 import (
+       "bytes"
+       "fmt"
+       "log"
        . "strings"
        "testing"
 )
 
+var _ = log.Printf
+
 type ReplacerTest struct {
-       m   *Replacer
+       r   *Replacer
        in  string
        out string
 }
@@ -31,6 +36,10 @@ var replacer = NewReplacer("aaa", "3[aaa]", "aa", "2[aa]", "a", "1[a]", "i", "i"
        "longerst", "most long", "longer", "medium", "long", "short",
        "X", "Y", "Y", "Z")
 
+var capitalLetters = NewReplacer("a", "A", "b", "B")
+
+var blankToXReplacer = NewReplacer("", "X", "o", "O")
+
 var ReplacerTests = []ReplacerTest{
        {htmlEscaper, "No changes", "No changes"},
        {htmlEscaper, "I <3 escaping & stuff", "I &lt;3 escaping &amp; stuff"},
@@ -38,38 +47,98 @@ var ReplacerTests = []ReplacerTest{
        {replacer, "fooaaabar", "foo3[aaa]b1[a]r"},
        {replacer, "long, longerst, longer", "short, most long, medium"},
        {replacer, "XiX", "YiY"},
+       {capitalLetters, "brad", "BrAd"},
+       {capitalLetters, Repeat("a", (32<<10)+123), Repeat("A", (32<<10)+123)},
+       {blankToXReplacer, "oo", "XOXOX"},
 }
 
 func TestReplacer(t *testing.T) {
        for i, tt := range ReplacerTests {
-               if s := tt.m.Replace(tt.in); s != tt.out {
+               if s := tt.r.Replace(tt.in); s != tt.out {
                        t.Errorf("%d. Replace(%q) = %q, want %q", i, tt.in, s, tt.out)
                }
+               var buf bytes.Buffer
+               n, err := tt.r.WriteString(&buf, tt.in)
+               if err != nil {
+                       t.Errorf("%d. WriteString: %v", i, err)
+                       continue
+               }
+               got := buf.String()
+               if got != tt.out {
+                       t.Errorf("%d. WriteString(%q) wrote %q, want %q", i, tt.in, got, tt.out)
+                       continue
+               }
+               if n != len(tt.out) {
+                       t.Errorf("%d. WriteString(%q) wrote correct string but reported %d bytes; want %d (%q)",
+                               i, tt.in, n, len(tt.out), tt.out)
+               }
        }
 }
 
-var slowReplacer = NewReplacer("&&", "&amp;", "<<", "&lt;", ">>", "&gt;", "\"\"", "&quot;", "''", "&apos;")
+// pickAlgorithmTest is a test that verifies that given input for a
+// Replacer that we pick the correct algorithm.
+type pickAlgorithmTest struct {
+       r    *Replacer
+       want string // name of algorithm
+}
+
+var pickAlgorithmTests = []pickAlgorithmTest{
+       {capitalLetters, "*strings.byteReplacer"},
+       {NewReplacer("a", "A", "b", "Bb"), "*strings.genericReplacer"},
+}
 
-func BenchmarkReplacerSingleByte(b *testing.B) {
-       str := "I <3 benchmarking html & other stuff too >:D"
-       n := 0
+func TestPickAlgorithm(t *testing.T) {
+       for i, tt := range pickAlgorithmTests {
+               got := fmt.Sprintf("%T", tt.r.Replacer())
+               if got != tt.want {
+                       t.Errorf("%d. algorithm = %s, want %s", i, got, tt.want)
+               }
+       }
+}
+
+func BenchmarkGenericMatch(b *testing.B) {
+       str := Repeat("A", 100) + Repeat("B", 100)
+       generic := NewReplacer("a", "A", "b", "B", "12", "123") // varying lengths forces generic
+       for i := 0; i < b.N; i++ {
+               generic.Replace(str)
+       }
+}
+
+func BenchmarkByteByteNoMatch(b *testing.B) {
+       str := Repeat("A", 100) + Repeat("B", 100)
        for i := 0; i < b.N; i++ {
-               n += len(htmlEscaper.Replace(str))
+               capitalLetters.Replace(str)
        }
 }
 
-func BenchmarkReplaceMap(b *testing.B) {
-       str := "I <<3 benchmarking html && other stuff too >>:D"
-       n := 0
+func BenchmarkByteByteMatch(b *testing.B) {
+       str := Repeat("a", 100) + Repeat("b", 100)
        for i := 0; i < b.N; i++ {
-               n += len(slowReplacer.Replace(str))
+               capitalLetters.Replace(str)
        }
 }
 
-func BenchmarkOldHTTPHTMLReplace(b *testing.B) {
-       str := "I <3 benchmarking html & other stuff too >:D"
-       n := 0
+// BenchmarkByteByteReplaces compares byteByteImpl against multiple Replaces.
+func BenchmarkByteByteReplaces(b *testing.B) {
+       str := Repeat("a", 100) + Repeat("b", 100)
+       for i := 0; i < b.N; i++ {
+               Replace(Replace(str, "a", "A", -1), "b", "B", -1)
+       }
+}
+
+// BenchmarkByteByteMap compares byteByteImpl against Map.
+func BenchmarkByteByteMap(b *testing.B) {
+       str := Repeat("a", 100) + Repeat("b", 100)
+       fn := func(r int) int {
+               switch r {
+               case 'a':
+                       return int('A')
+               case 'b':
+                       return int('B')
+               }
+               return r
+       }
        for i := 0; i < b.N; i++ {
-               n += len(oldhtmlEscape(str))
+               Map(fn, str)
        }
 }