]> Cypherpunks repositories - gostls13.git/commitdiff
strings: use Rabin-Karp algorithm for LastIndex.
authorRui Ueyama <ruiu@google.com>
Mon, 1 Sep 2014 07:47:57 +0000 (17:47 +1000)
committerNigel Tao <nigeltao@golang.org>
Mon, 1 Sep 2014 07:47:57 +0000 (17:47 +1000)
benchmark                  old ns/op     new ns/op     delta
BenchmarkSingleMatch       49443         52275         +5.73%
BenchmarkIndex             28.8          27.4          -4.86%
BenchmarkLastIndex         14.5          14.0          -3.45%
BenchmarkLastIndexHard1    3982782       2309200       -42.02%
BenchmarkLastIndexHard2    3985562       2287715       -42.60%
BenchmarkLastIndexHard3    3555259       2282866       -35.79%

LGTM=josharian, nigeltao
R=golang-codereviews, ality, josharian, bradfitz, dave, nigeltao, gobot, nightlyone
CC=golang-codereviews
https://golang.org/cl/102560043

src/pkg/strings/strings.go
src/pkg/strings/strings_test.go

index 5f19695d3ffb75b1700c1063f4ea3785190e1af9..761f32a0681f2e008853995691553493f01e70cd 100644 (file)
@@ -43,13 +43,29 @@ func explode(s string, n int) []string {
 // primeRK is the prime base used in Rabin-Karp algorithm.
 const primeRK = 16777619
 
-// hashstr returns the hash and the appropriate multiplicative
+// hashStr returns the hash and the appropriate multiplicative
 // factor for use in Rabin-Karp algorithm.
-func hashstr(sep string) (uint32, uint32) {
+func hashStr(sep string) (uint32, uint32) {
        hash := uint32(0)
        for i := 0; i < len(sep); i++ {
                hash = hash*primeRK + uint32(sep[i])
+       }
+       var pow, sq uint32 = 1, primeRK
+       for i := len(sep); i > 0; i >>= 1 {
+               if i&1 != 0 {
+                       pow *= sq
+               }
+               sq *= sq
+       }
+       return hash, pow
+}
 
+// hashStrRev returns the hash of the reverse of sep and the
+// appropriate multiplicative factor for use in Rabin-Karp algorithm.
+func hashStrRev(sep string) (uint32, uint32) {
+       hash := uint32(0)
+       for i := len(sep) - 1; i >= 0; i-- {
+               hash = hash*primeRK + uint32(sep[i])
        }
        var pow, sq uint32 = 1, primeRK
        for i := len(sep); i > 0; i >>= 1 {
@@ -85,7 +101,8 @@ func Count(s, sep string) int {
                }
                return 0
        }
-       hashsep, pow := hashstr(sep)
+       // Rabin-Karp search
+       hashsep, pow := hashStr(sep)
        h := uint32(0)
        for i := 0; i < len(sep); i++ {
                h = h*primeRK + uint32(s[i])
@@ -139,8 +156,8 @@ func Index(s, sep string) int {
        case n > len(s):
                return -1
        }
-       // Hash sep.
-       hashsep, pow := hashstr(sep)
+       // Rabin-Karp search
+       hashsep, pow := hashStr(sep)
        var h uint32
        for i := 0; i < n; i++ {
                h = h*primeRK + uint32(s[i])
@@ -163,22 +180,41 @@ func Index(s, sep string) int {
 // LastIndex returns the index of the last instance of sep in s, or -1 if sep is not present in s.
 func LastIndex(s, sep string) int {
        n := len(sep)
-       if n == 0 {
+       switch {
+       case n == 0:
                return len(s)
-       }
-       c := sep[0]
-       if n == 1 {
+       case n == 1:
                // special case worth making fast
+               c := sep[0]
                for i := len(s) - 1; i >= 0; i-- {
                        if s[i] == c {
                                return i
                        }
                }
                return -1
+       case n == len(s):
+               if sep == s {
+                       return 0
+               }
+               return -1
+       case n > len(s):
+               return -1
+       }
+       // Rabin-Karp search from the end of the string
+       hashsep, pow := hashStrRev(sep)
+       last := len(s) - n
+       var h uint32
+       for i := len(s) - 1; i >= last; i-- {
+               h = h*primeRK + uint32(s[i])
+       }
+       if h == hashsep && s[last:] == sep {
+               return last
        }
-       // n > 1
-       for i := len(s) - n; i >= 0; i-- {
-               if s[i] == c && s[i:i+n] == sep {
+       for i := last - 1; i >= 0; i-- {
+               h *= primeRK
+               h += uint32(s[i])
+               h -= pow * uint32(s[i+n])
+               if h == hashsep && s[i:i+n] == sep {
                        return i
                }
        }
index 27c0314fe80010d1880cbd2058104b3925e74fe0..7bb81ef3ca14ff40bd96b7c030ee79f1619de878 100644 (file)
@@ -168,6 +168,15 @@ func BenchmarkIndex(b *testing.B) {
        }
 }
 
+func BenchmarkLastIndex(b *testing.B) {
+       if got := Index(benchmarkString, "v"); got != 17 {
+               b.Fatalf("wrong index: expected 17, got=%d", got)
+       }
+       for i := 0; i < b.N; i++ {
+               LastIndex(benchmarkString, "v")
+       }
+}
+
 func BenchmarkIndexByte(b *testing.B) {
        if got := IndexByte(benchmarkString, 'v'); got != 17 {
                b.Fatalf("wrong index: expected 17, got=%d", got)
@@ -1087,6 +1096,12 @@ func benchmarkIndexHard(b *testing.B, sep string) {
        }
 }
 
+func benchmarkLastIndexHard(b *testing.B, sep string) {
+       for i := 0; i < b.N; i++ {
+               LastIndex(benchInputHard, sep)
+       }
+}
+
 func benchmarkCountHard(b *testing.B, sep string) {
        for i := 0; i < b.N; i++ {
                Count(benchInputHard, sep)
@@ -1097,6 +1112,10 @@ func BenchmarkIndexHard1(b *testing.B) { benchmarkIndexHard(b, "<>") }
 func BenchmarkIndexHard2(b *testing.B) { benchmarkIndexHard(b, "</pre>") }
 func BenchmarkIndexHard3(b *testing.B) { benchmarkIndexHard(b, "<b>hello world</b>") }
 
+func BenchmarkLastIndexHard1(b *testing.B) { benchmarkLastIndexHard(b, "<>") }
+func BenchmarkLastIndexHard2(b *testing.B) { benchmarkLastIndexHard(b, "</pre>") }
+func BenchmarkLastIndexHard3(b *testing.B) { benchmarkLastIndexHard(b, "<b>hello world</b>") }
+
 func BenchmarkCountHard1(b *testing.B) { benchmarkCountHard(b, "<>") }
 func BenchmarkCountHard2(b *testing.B) { benchmarkCountHard(b, "</pre>") }
 func BenchmarkCountHard3(b *testing.B) { benchmarkCountHard(b, "<b>hello world</b>") }