]> Cypherpunks repositories - gostls13.git/commitdiff
bytes: add optimized countByte for amd64
authorJosselin Costanzi <josselin@costanzi.fr>
Sun, 19 Mar 2017 11:18:08 +0000 (12:18 +0100)
committerKeith Randall <khr@golang.org>
Tue, 21 Mar 2017 20:25:17 +0000 (20:25 +0000)
Use SSE/AVX2 when counting a single byte.
Inspired from runtime indexbyte implementation.

Benchmark against previous implementation, where
1 byte in every 8 is the one we are looking for:

* On a machine without AVX2
name               old time/op   new time/op     delta
CountSingle/10-4    61.8ns ±10%     15.6ns ±11%    -74.83%  (p=0.000 n=10+10)
CountSingle/32-4     100ns ± 4%       17ns ±10%    -82.54%  (p=0.000 n=10+9)
CountSingle/4K-4    9.66µs ± 3%     0.37µs ± 6%    -96.21%  (p=0.000 n=10+10)
CountSingle/4M-4    11.0ms ± 6%      0.4ms ± 4%    -96.04%  (p=0.000 n=10+10)
CountSingle/64M-4    194ms ± 8%        8ms ± 2%    -95.64%  (p=0.000 n=10+10)

name               old speed     new speed       delta
CountSingle/10-4   162MB/s ±10%    645MB/s ±10%   +297.00%  (p=0.000 n=10+10)
CountSingle/32-4   321MB/s ± 5%   1844MB/s ± 9%   +474.79%  (p=0.000 n=10+9)
CountSingle/4K-4   424MB/s ± 3%  11169MB/s ± 6%  +2533.10%  (p=0.000 n=10+10)
CountSingle/4M-4   381MB/s ± 7%   9609MB/s ± 4%  +2421.88%  (p=0.000 n=10+10)
CountSingle/64M-4  346MB/s ± 7%   7924MB/s ± 2%  +2188.78%  (p=0.000 n=10+10)

* On a machine with AVX2
name               old time/op   new time/op     delta
CountSingle/10-8    37.1ns ± 3%      8.2ns ± 1%    -77.80%  (p=0.000 n=10+10)
CountSingle/32-8    66.1ns ± 3%      9.8ns ± 2%    -85.23%  (p=0.000 n=10+10)
CountSingle/4K-8    7.36µs ± 3%     0.11µs ± 1%    -98.54%  (p=0.000 n=10+10)
CountSingle/4M-8    7.46ms ± 2%     0.15ms ± 2%    -97.95%  (p=0.000 n=10+9)
CountSingle/64M-8    124ms ± 2%        6ms ± 4%    -95.09%  (p=0.000 n=10+10)

name               old speed     new speed       delta
CountSingle/10-8   269MB/s ± 3%   1213MB/s ± 1%   +350.32%  (p=0.000 n=10+10)
CountSingle/32-8   484MB/s ± 4%   3277MB/s ± 2%   +576.66%  (p=0.000 n=10+10)
CountSingle/4K-8   556MB/s ± 3%  37933MB/s ± 1%  +6718.36%  (p=0.000 n=10+10)
CountSingle/4M-8   562MB/s ± 2%  27444MB/s ± 3%  +4783.43%  (p=0.000 n=10+9)
CountSingle/64M-8  543MB/s ± 2%  11054MB/s ± 3%  +1935.81%  (p=0.000 n=10+10)

Fixes #19411

Change-Id: Ieaf20b1fabccabe767c55c66e242e86f3617f883
Reviewed-on: https://go-review.googlesource.com/38258
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
src/bytes/bytes.go
src/bytes/bytes_amd64.go
src/bytes/bytes_amd64.s [new file with mode: 0644]
src/bytes/bytes_generic.go
src/bytes/bytes_s390x.go
src/bytes/bytes_test.go
src/bytes/export_test.go
src/cmd/vet/all/whitelist/amd64.txt
src/runtime/asm_amd64.s
src/runtime/runtime2.go

index 029609afba2e777508e9b2be28bca1622e2de4b2..f461d2b3ce7d1027563ba80c4b24bed1072f67cb 100644 (file)
@@ -46,9 +46,8 @@ func explode(s []byte, n int) [][]byte {
        return a[0:na]
 }
 
-// Count counts the number of non-overlapping instances of sep in s.
-// If sep is an empty slice, Count returns 1 + the number of Unicode code points in s.
-func Count(s, sep []byte) int {
+// countGeneric actualy implements Count
+func countGeneric(s, sep []byte) int {
        n := 0
        // special case
        if len(sep) == 0 {
index 6affff6334da591056dbe3b5670dc5d92adfc233..5b42f272d09127f2639289cab1d14e1b3830f617 100644 (file)
@@ -10,6 +10,7 @@ package bytes
 // indexShortStr requires 2 <= len(c) <= shortStringLen
 func indexShortStr(s, c []byte) int // ../runtime/asm_$GOARCH.s
 func supportAVX2() bool             // ../runtime/asm_$GOARCH.s
+func supportPOPCNT() bool           // ../runtime/asm_$GOARCH.s
 
 var shortStringLen int
 
@@ -94,6 +95,18 @@ func Index(s, sep []byte) int {
        return -1
 }
 
+// Special case for when we must count occurences of a single byte.
+func countByte(s []byte, c byte) int
+
+// Count counts the number of non-overlapping instances of sep in s.
+// If sep is an empty slice, Count returns 1 + the number of Unicode code points in s.
+func Count(s, sep []byte) int {
+       if len(sep) == 1 && supportPOPCNT() {
+               return countByte(s, sep[0])
+       }
+       return countGeneric(s, sep)
+}
+
 // primeRK is the prime base used in Rabin-Karp algorithm.
 const primeRK = 16777619
 
diff --git a/src/bytes/bytes_amd64.s b/src/bytes/bytes_amd64.s
new file mode 100644 (file)
index 0000000..f4cbadf
--- /dev/null
@@ -0,0 +1,183 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+// We use:
+//   SI: data
+//   BX: data len
+//   AL: byte sought
+// This require the POPCNT instruction
+TEXT ·countByte(SB),NOSPLIT,$0-40
+       MOVQ s+0(FP), SI
+       MOVQ s_len+8(FP), BX
+       MOVB c+24(FP), AL
+
+       // Shuffle X0 around so that each byte contains
+       // the character we're looking for.
+       MOVD AX, X0
+       PUNPCKLBW X0, X0
+       PUNPCKLBW X0, X0
+       PSHUFL $0, X0, X0
+
+       CMPQ BX, $16
+       JLT small
+
+       MOVQ $0, R12 // Accumulator
+
+       MOVQ SI, DI
+
+       CMPQ BX, $32
+       JA avx2
+sse:
+       LEAQ    -16(SI)(BX*1), AX       // AX = address of last 16 bytes
+       JMP     sseloopentry
+
+sseloop:
+       // Move the next 16-byte chunk of the data into X1.
+       MOVOU   (DI), X1
+       // Compare bytes in X0 to X1.
+       PCMPEQB X0, X1
+       // Take the top bit of each byte in X1 and put the result in DX.
+       PMOVMSKB X1, DX
+       // Count number of matching bytes
+       POPCNTL DX, DX
+       // Accumulate into R12
+       ADDQ DX, R12
+       // Advance to next block.
+       ADDQ    $16, DI
+sseloopentry:
+       CMPQ    DI, AX
+       JBE     sseloop
+
+       // Get the number of bytes to consider in the last 16 bytes
+       ANDQ $15, BX
+       JZ end
+
+       // Create mask to ignore overlap between previous 16 byte block
+       // and the next.
+       MOVQ $16,CX
+       SUBQ BX, CX
+       MOVQ $0xFFFF, R10
+       SARQ CL, R10
+       SALQ CL, R10
+
+       // Process the last 16-byte chunk. This chunk may overlap with the
+       // chunks we've already searched so we need to mask part of it.
+       MOVOU   (AX), X1
+       PCMPEQB X0, X1
+       PMOVMSKB X1, DX
+       // Apply mask
+       ANDQ R10, DX
+       POPCNTL DX, DX
+       ADDQ DX, R12
+end:
+       MOVQ R12, ret+32(FP)
+       RET
+
+// handle for lengths < 16
+small:
+       TESTQ   BX, BX
+       JEQ     endzero
+
+       // Check if we'll load across a page boundary.
+       LEAQ    16(SI), AX
+       TESTW   $0xff0, AX
+       JEQ     endofpage
+
+       // We must ignore high bytes as they aren't part of our slice.
+       // Create mask.
+       MOVB BX, CX
+       MOVQ $1, R10
+       SALQ CL, R10
+       SUBQ $1, R10
+
+       // Load data
+       MOVOU   (SI), X1
+       // Compare target byte with each byte in data.
+       PCMPEQB X0, X1
+       // Move result bits to integer register.
+       PMOVMSKB X1, DX
+       // Apply mask
+       ANDQ R10, DX
+       POPCNTL DX, DX
+       // Directly return DX, we don't need to accumulate
+       // since we have <16 bytes.
+       MOVQ    DX, ret+32(FP)
+       RET
+endzero:
+       MOVQ $0, ret+32(FP)
+       RET
+
+endofpage:
+       // We must ignore low bytes as they aren't part of our slice.
+       MOVQ $16,CX
+       SUBQ BX, CX
+       MOVQ $0xFFFF, R10
+       SARQ CL, R10
+       SALQ CL, R10
+
+       // Load data into the high end of X1.
+       MOVOU   -16(SI)(BX*1), X1
+       // Compare target byte with each byte in data.
+       PCMPEQB X0, X1
+       // Move result bits to integer register.
+       PMOVMSKB X1, DX
+       // Apply mask
+       ANDQ R10, DX
+       // Directly return DX, we don't need to accumulate
+       // since we have <16 bytes.
+       POPCNTL DX, DX
+       MOVQ    DX, ret+32(FP)
+       RET
+
+avx2:
+       CMPB   runtime·support_avx2(SB), $1
+       JNE sse
+       MOVD AX, X0
+       LEAQ -32(SI)(BX*1), R11
+       VPBROADCASTB  X0, Y1
+avx2_loop:
+       VMOVDQU (DI), Y2
+       VPCMPEQB Y1, Y2, Y3
+       VPMOVMSKB Y3, DX
+       POPCNTL DX, DX
+       ADDQ DX, R12
+       ADDQ $32, DI
+       CMPQ DI, R11
+       JLE avx2_loop
+
+       // If last block is already processed,
+       // skip to the end.
+       CMPQ DI, R11
+       JEQ endavx
+
+       // Load address of the last 32 bytes.
+       // There is an overlap with the previous block.
+       MOVQ R11, DI
+       VMOVDQU (DI), Y2
+       VPCMPEQB Y1, Y2, Y3
+       VPMOVMSKB Y3, DX
+       // Exit AVX mode.
+       VZEROUPPER
+
+       // Create mask to ignore overlap between previous 32 byte block
+       // and the next.
+       ANDQ $31, BX
+       MOVQ $32,CX
+       SUBQ BX, CX
+       MOVQ $0xFFFFFFFF, R10
+       SARQ CL, R10
+       SALQ CL, R10
+       // Apply mask
+       ANDQ R10, DX
+       POPCNTL DX, DX
+       ADDQ DX, R12
+       MOVQ R12, ret+32(FP)
+       RET
+endavx:
+       // Exit AVX mode.
+       VZEROUPPER
+       MOVQ R12, ret+32(FP)
+       RET
index 06c9e1f26c1da8675148db5e8a2e89591122f71e..98454bc121b8b3971ff70cfd420be7ab5c82bfbc 100644 (file)
@@ -39,3 +39,9 @@ func Index(s, sep []byte) int {
        }
        return -1
 }
+
+// Count counts the number of non-overlapping instances of sep in s.
+// If sep is an empty slice, Count returns 1 + the number of Unicode code points in s.
+func Count(s, sep []byte) int {
+       return countGeneric(s, sep)
+}
index 988c6034aa54350d7f809d5d9a49ecff759b67f8..68b57301fe8e7b7b0bf248077a2f715c81f833a5 100644 (file)
@@ -97,6 +97,12 @@ func Index(s, sep []byte) int {
        return -1
 }
 
+// Count counts the number of non-overlapping instances of sep in s.
+// If sep is an empty slice, Count returns 1 + the number of Unicode code points in s.
+func Count(s, sep []byte) int {
+       return countGeneric(s, sep)
+}
+
 // primeRK is the prime base used in Rabin-Karp algorithm.
 const primeRK = 16777619
 
index dd8bdf2b04a203e834880df3ab05300f785ab1f4..ca0cdbb7c9f02bc78a092cf3cb14e89cb93870cd 100644 (file)
@@ -396,6 +396,79 @@ func TestIndexRune(t *testing.T) {
        }
 }
 
+// test count of a single byte across page offsets
+func TestCountByte(t *testing.T) {
+       b := make([]byte, 5015) // bigger than a page
+       windows := []int{1, 2, 3, 4, 15, 16, 17, 31, 32, 33, 63, 64, 65, 128}
+       testCountWindow := func(i, window int) {
+               for j := 0; j < window; j++ {
+                       b[i+j] = byte(100)
+                       p := Count(b[i:i+window], []byte{100})
+                       if p != j+1 {
+                               t.Errorf("TestCountByte.Count(%q, 100) = %d", b[i:i+window], p)
+                       }
+                       pGeneric := CountGeneric(b[i:i+window], []byte{100})
+                       if pGeneric != j+1 {
+                               t.Errorf("TestCountByte.CountGeneric(%q, 100) = %d", b[i:i+window], p)
+                       }
+               }
+       }
+
+       maxWnd := windows[len(windows)-1]
+
+       for i := 0; i <= 2*maxWnd; i++ {
+               for _, window := range windows {
+                       if window > len(b[i:]) {
+                               window = len(b[i:])
+                       }
+                       testCountWindow(i, window)
+                       for j := 0; j < window; j++ {
+                               b[i+j] = byte(0)
+                       }
+               }
+       }
+       for i := 4096 - (maxWnd + 1); i < len(b); i++ {
+               for _, window := range windows {
+                       if window > len(b[i:]) {
+                               window = len(b[i:])
+                       }
+                       testCountWindow(i, window)
+                       for j := 0; j < window; j++ {
+                               b[i+j] = byte(0)
+                       }
+               }
+       }
+}
+
+// Make sure we don't count bytes outside our window
+func TestCountByteNoMatch(t *testing.T) {
+       b := make([]byte, 5015)
+       windows := []int{1, 2, 3, 4, 15, 16, 17, 31, 32, 33, 63, 64, 65, 128}
+       for i := 0; i <= len(b); i++ {
+               for _, window := range windows {
+                       if window > len(b[i:]) {
+                               window = len(b[i:])
+                       }
+                       // Fill the window with non-match
+                       for j := 0; j < window; j++ {
+                               b[i+j] = byte(100)
+                       }
+                       // Try to find something that doesn't exist
+                       p := Count(b[i:i+window], []byte{0})
+                       if p != 0 {
+                               t.Errorf("TestCountByteNoMatch(%q, 0) = %d", b[i:i+window], p)
+                       }
+                       pGeneric := CountGeneric(b[i:i+window], []byte{0})
+                       if pGeneric != 0 {
+                               t.Errorf("TestCountByteNoMatch.CountGeneric(%q, 100) = %d", b[i:i+window], p)
+                       }
+                       for j := 0; j < window; j++ {
+                               b[i+j] = byte(0)
+                       }
+               }
+       }
+}
+
 var bmbuf []byte
 
 func valName(x int) string {
@@ -589,6 +662,26 @@ func BenchmarkCountEasy(b *testing.B) {
        })
 }
 
+func BenchmarkCountSingle(b *testing.B) {
+       benchBytes(b, indexSizes, func(b *testing.B, n int) {
+               buf := bmbuf[0:n]
+               step := 8
+               for i := 0; i < len(buf); i += step {
+                       buf[i] = 1
+               }
+               expect := (len(buf) + (step - 1)) / step
+               for i := 0; i < b.N; i++ {
+                       j := Count(buf, []byte{1})
+                       if j != expect {
+                               b.Fatal("bad count", j, expect)
+                       }
+               }
+               for i := 0; i < len(buf); i++ {
+                       buf[i] = 0
+               }
+       })
+}
+
 type ExplodeTest struct {
        s string
        n int
index f61523e60bbb3fd209594912ebcaa76cde9acd20..823c8b09eef8fc048a7b93bccbbdcfbc501883e2 100644 (file)
@@ -7,3 +7,4 @@ package bytes
 // Export func for testing
 var IndexBytePortable = indexBytePortable
 var EqualPortable = equalPortable
+var CountGeneric = countGeneric
index df4ec84195a5afa774d811b74c96b890e57e996d..92a693af837c1cf940c9d15e9eb3c030f0f69cfd 100644 (file)
@@ -17,6 +17,7 @@ runtime/asm_amd64.s: [GOARCH] cannot check cross-package assembly function: Comp
 runtime/asm_amd64.s: [amd64] cannot check cross-package assembly function: indexShortStr is in package bytes
 runtime/asm_amd64.s: [amd64] cannot check cross-package assembly function: supportAVX2 is in package strings
 runtime/asm_amd64.s: [amd64] cannot check cross-package assembly function: supportAVX2 is in package bytes
+runtime/asm_amd64.s: [amd64] cannot check cross-package assembly function: supportPOPCNT is in package bytes
 
 // Intentionally missing declarations. These are special assembly routines.
 // Some are jumped into from other routines, with values in specific registers.
index c6ff8379e646b08be1d8d9ba460f3c7faad967eb..c0a5048edab16883e9887bece8c32b4ffbb4c877 100644 (file)
@@ -91,8 +91,13 @@ testbmi1:
 testbmi2:
        MOVB    $0, runtime·support_bmi2(SB)
        TESTL   $(1<<8), runtime·cpuid_ebx7(SB) // check for BMI2 bit
-       JEQ     nocpuinfo
+       JEQ     testpopcnt
        MOVB    $1, runtime·support_bmi2(SB)
+testpopcnt:
+       MOVB    $0, runtime·support_popcnt(SB)
+       TESTL   $(1<<23), runtime·cpuid_ecx(SB) // check for POPCNT bit
+       JEQ     nocpuinfo
+       MOVB    $1, runtime·support_popcnt(SB)
 nocpuinfo:     
        
        // if there is an _cgo_init, call it.
@@ -1697,6 +1702,11 @@ TEXT bytes·supportAVX2(SB),NOSPLIT,$0-1
        MOVB AX, ret+0(FP)
        RET
 
+TEXT bytes·supportPOPCNT(SB),NOSPLIT,$0-1
+       MOVBLZX runtime·support_popcnt(SB), AX
+       MOVB AX, ret+0(FP)
+       RET
+
 TEXT strings·indexShortStr(SB),NOSPLIT,$0-40
        MOVQ s+0(FP), DI
        // We want len in DX and AX, because PCMPESTRI implicitly consumes them
index 8b6bddf456eb0f2aed52ee6591c7c0854eb742ee..50e39acaa57e2b8fc24f76468163c53e23b7c1f2 100644 (file)
@@ -728,6 +728,7 @@ var (
        support_avx2      bool
        support_bmi1      bool
        support_bmi2      bool
+       support_popcnt    bool
 
        goarm                uint8 // set by cmd/link on arm systems
        framepointer_enabled bool  // set by cmd/link