]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: add hook for managing io.Copy buffers per request
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 28 Apr 2015 01:10:20 +0000 (18:10 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 28 Oct 2015 05:53:28 +0000 (05:53 +0000)
Adds ReverseProxy.BufferPool for users with sensitive allocation
requirements. Permits avoiding 32 KB of io.Copy garbage per request.

Fixes #10277

Change-Id: I5dfd58fa70a363ead4be56405e507df90d871719
Reviewed-on: https://go-review.googlesource.com/9399
Reviewed-by: Keith Randall <khr@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/httputil/reverseproxy.go
src/net/http/httputil/reverseproxy_test.go

index 95a99ddb9dd0b19cbc24b4884fde36f78808348a..4dba352a4fa5110f94c2e3a012533fc2784e639f 100644 (file)
@@ -46,6 +46,18 @@ type ReverseProxy struct {
        // If nil, logging goes to os.Stderr via the log package's
        // standard logger.
        ErrorLog *log.Logger
+
+       // BufferPool optionally specifies a buffer pool to
+       // get byte slices for use by io.CopyBuffer when
+       // copying HTTP response bodies.
+       BufferPool BufferPool
+}
+
+// A BufferPool is an interface for getting and returning temporary
+// byte slices for use by io.CopyBuffer.
+type BufferPool interface {
+       Get() []byte
+       Put([]byte)
 }
 
 func singleJoiningSlash(a, b string) string {
@@ -245,7 +257,14 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
                }
        }
 
-       io.Copy(dst, src)
+       var buf []byte
+       if p.BufferPool != nil {
+               buf = p.BufferPool.Get()
+       }
+       io.CopyBuffer(dst, src, buf)
+       if p.BufferPool != nil {
+               p.BufferPool.Put(buf)
+       }
 }
 
 func (p *ReverseProxy) logf(format string, args ...interface{}) {
index 1d309614e2b64019db287ac6e351a4c4c76338da..5f6fc56e07a5b926faaebc88d11e3b0f5e6384fc 100644 (file)
@@ -8,13 +8,16 @@ package httputil
 
 import (
        "bufio"
+       "io"
        "io/ioutil"
        "log"
        "net/http"
        "net/http/httptest"
        "net/url"
        "reflect"
+       "strconv"
        "strings"
+       "sync"
        "testing"
        "time"
 )
@@ -316,3 +319,68 @@ func TestNilBody(t *testing.T) {
                t.Errorf("Got %q; want %q", slurp, "hi")
        }
 }
+
+type bufferPool struct {
+       get func() []byte
+       put func([]byte)
+}
+
+func (bp bufferPool) Get() []byte  { return bp.get() }
+func (bp bufferPool) Put(v []byte) { bp.put(v) }
+
+func TestReverseProxyGetPutBuffer(t *testing.T) {
+       const msg = "hi"
+       backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               io.WriteString(w, msg)
+       }))
+       defer backend.Close()
+
+       backendURL, err := url.Parse(backend.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       var (
+               mu  sync.Mutex
+               log []string
+       )
+       addLog := func(event string) {
+               mu.Lock()
+               defer mu.Unlock()
+               log = append(log, event)
+       }
+       rp := NewSingleHostReverseProxy(backendURL)
+       const size = 1234
+       rp.BufferPool = bufferPool{
+               get: func() []byte {
+                       addLog("getBuf")
+                       return make([]byte, size)
+               },
+               put: func(p []byte) {
+                       addLog("putBuf-" + strconv.Itoa(len(p)))
+               },
+       }
+       frontend := httptest.NewServer(rp)
+       defer frontend.Close()
+
+       req, _ := http.NewRequest("GET", frontend.URL, nil)
+       req.Close = true
+       res, err := http.DefaultClient.Do(req)
+       if err != nil {
+               t.Fatalf("Get: %v", err)
+       }
+       slurp, err := ioutil.ReadAll(res.Body)
+       res.Body.Close()
+       if err != nil {
+               t.Fatalf("reading body: %v", err)
+       }
+       if string(slurp) != msg {
+               t.Errorf("msg = %q; want %q", slurp, msg)
+       }
+       wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
+       mu.Lock()
+       defer mu.Unlock()
+       if !reflect.DeepEqual(log, wantLog) {
+               t.Errorf("Log events = %q; want %q", log, wantLog)
+       }
+}