]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: reduce memory usage when hijacking
authorJakob Ackermann <das7pad@outlook.com>
Mon, 3 Mar 2025 22:16:38 +0000 (22:16 +0000)
committerGopher Robot <gobot@golang.org>
Wed, 9 Apr 2025 20:49:52 +0000 (13:49 -0700)
Previously, Hijack allocated a new write buffer and the existing
connection write buffer used an extra 4KiB of memory until the handler
finished and the "conn" was garbage collected. Now, hijack re-uses the
existing write buffer and re-attaches it to the raw connection to avoid
referencing the net/http "conn" after returning.

After a handler that hijacked exited, the "conn" reference in
"connReader" will now be unset. This allows all of the "conn",
"response" and "Request" to get garbage collected.
Overall, this is reducing the memory usage by 43% or 6.7KiB per hijacked
connection (see BenchmarkServerHijackMemoryUsage in an earlier revision
of the CL).

CloseNotify will continue to work _before_ the handler has exited
(i.e. while the "conn" is still referenced in "connReader"). This aligns
with the documentation of CloseNotifier:
> After the Handler has returned, there is no guarantee that the channel
> receives a value.

goos: linux
goarch: amd64
pkg: net/http
cpu: Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz
               │   before    │             after              │
               │   sec/op    │    sec/op     vs base          │
ServerHijack-8   42.59µ ± 8%   39.47µ ± 16%  ~ (p=0.481 n=10)

               │    before    │                after                 │
               │     B/op     │     B/op      vs base                │
ServerHijack-8   16.12Ki ± 0%   12.06Ki ± 0%  -25.16% (p=0.000 n=10)

               │   before   │               after               │
               │ allocs/op  │ allocs/op   vs base               │
ServerHijack-8   51.00 ± 0%   49.00 ± 0%  -3.92% (p=0.000 n=10)

Change-Id: I20a37ee314ed0d47463a4657d712154e78e48138
GitHub-Last-Rev: 80f09dfa273035f53cdd72845e5c5fb129c3e230
GitHub-Pull-Request: golang/go#70756
Reviewed-on: https://go-review.googlesource.com/c/go/+/634855
Reviewed-by: Sean Liao <sean@liao.dev>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Auto-Submit: Sean Liao <sean@liao.dev>

src/net/http/serve_test.go
src/net/http/server.go

index b1d9f0b3e3eb792590f24f0f3973c64546dfdfb1..c603c201d591e0086dbac7483179015345df15b9 100644 (file)
@@ -6139,6 +6139,50 @@ func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
        <-done
 }
 
+// Test that the bufio.Reader returned by Hijack yields the entire body.
+func TestServerHijackGetsFullBody(t *testing.T) {
+       run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
+}
+func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
+       if runtime.GOOS == "plan9" {
+               t.Skip("skipping test; see https://golang.org/issue/18657")
+       }
+       done := make(chan struct{})
+       needle := strings.Repeat("x", 100*1024) // assume: larger than net/http bufio size
+       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               defer close(done)
+
+               conn, buf, err := w.(Hijacker).Hijack()
+               if err != nil {
+                       t.Error(err)
+                       return
+               }
+               defer conn.Close()
+
+               got := make([]byte, len(needle))
+               n, err := io.ReadFull(buf.Reader, got)
+               if n != len(needle) || string(got) != needle || err != nil {
+                       t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
+               }
+       })).ts
+
+       cn, err := net.Dial("tcp", ts.Listener.Addr().String())
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer cn.Close()
+       buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
+       buf = append(buf, []byte(needle)...)
+       if _, err := cn.Write(buf); err != nil {
+               t.Fatal(err)
+       }
+
+       if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
+               t.Fatal(err)
+       }
+       <-done
+}
+
 // Like TestServerHijackGetsBackgroundByte above but sending a
 // immediate 1MB of data to the server to fill up the server's 4KB
 // buffer.
index be25e9a450be66c0350651c542b062c3bd0da619..49a9d30207fa798e9715908e240dfcd5903ad4d8 100644 (file)
@@ -324,12 +324,14 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
        rwc = c.rwc
        rwc.SetDeadline(time.Time{})
 
-       buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc))
        if c.r.hasByte {
                if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil {
                        return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err)
                }
        }
+       c.bufw.Reset(rwc)
+       buf = bufio.NewReadWriter(c.bufr, c.bufw)
+
        c.setState(rwc, StateHijacked, runHooks)
        return
 }
@@ -652,10 +654,13 @@ type readResult struct {
 // read sizes) with support for selectively keeping an io.Reader.Read
 // call blocked in a background goroutine to wait for activity and
 // trigger a CloseNotifier channel.
+// After a Handler has hijacked the conn and exited, connReader behaves like a
+// proxy for the net.Conn and the aforementioned behavior is bypassed.
 type connReader struct {
-       conn *conn
+       rwc net.Conn // rwc is the underlying network connection.
 
        mu      sync.Mutex // guards following
+       conn    *conn      // conn is nil after handler exit.
        hasByte bool
        byteBuf [1]byte
        cond    *sync.Cond
@@ -673,6 +678,12 @@ func (cr *connReader) lock() {
 
 func (cr *connReader) unlock() { cr.mu.Unlock() }
 
+func (cr *connReader) releaseConn() {
+       cr.lock()
+       defer cr.unlock()
+       cr.conn = nil
+}
+
 func (cr *connReader) startBackgroundRead() {
        cr.lock()
        defer cr.unlock()
@@ -683,12 +694,12 @@ func (cr *connReader) startBackgroundRead() {
                return
        }
        cr.inRead = true
-       cr.conn.rwc.SetReadDeadline(time.Time{})
+       cr.rwc.SetReadDeadline(time.Time{})
        go cr.backgroundRead()
 }
 
 func (cr *connReader) backgroundRead() {
-       n, err := cr.conn.rwc.Read(cr.byteBuf[:])
+       n, err := cr.rwc.Read(cr.byteBuf[:])
        cr.lock()
        if n == 1 {
                cr.hasByte = true
@@ -719,7 +730,7 @@ func (cr *connReader) backgroundRead() {
                // Ignore this error. It's the expected error from
                // another goroutine calling abortPendingRead.
        } else if err != nil {
-               cr.handleReadError(err)
+               cr.handleReadErrorLocked(err)
        }
        cr.aborted = false
        cr.inRead = false
@@ -734,18 +745,18 @@ func (cr *connReader) abortPendingRead() {
                return
        }
        cr.aborted = true
-       cr.conn.rwc.SetReadDeadline(aLongTimeAgo)
+       cr.rwc.SetReadDeadline(aLongTimeAgo)
        for cr.inRead {
                cr.cond.Wait()
        }
-       cr.conn.rwc.SetReadDeadline(time.Time{})
+       cr.rwc.SetReadDeadline(time.Time{})
 }
 
 func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
 func (cr *connReader) setInfiniteReadLimit()     { cr.remain = maxInt64 }
 func (cr *connReader) hitReadLimit() bool        { return cr.remain <= 0 }
 
-// handleReadError is called whenever a Read from the client returns a
+// handleReadErrorLocked is called whenever a Read from the client returns a
 // non-nil error.
 //
 // The provided non-nil err is almost always io.EOF or a "use of
@@ -754,14 +765,12 @@ func (cr *connReader) hitReadLimit() bool        { return cr.remain <= 0 }
 // development. Any error means the connection is dead and we should
 // down its context.
 //
-// It may be called from multiple goroutines.
-func (cr *connReader) handleReadError(_ error) {
+// The caller must hold connReader.mu.
+func (cr *connReader) handleReadErrorLocked(_ error) {
+       if cr.conn == nil {
+               return
+       }
        cr.conn.cancelCtx()
-       cr.closeNotify()
-}
-
-// may be called from multiple goroutines.
-func (cr *connReader) closeNotify() {
        if res := cr.conn.curReq.Load(); res != nil {
                res.closeNotify()
        }
@@ -769,9 +778,14 @@ func (cr *connReader) closeNotify() {
 
 func (cr *connReader) Read(p []byte) (n int, err error) {
        cr.lock()
+       if cr.conn == nil {
+               cr.unlock()
+               return cr.rwc.Read(p)
+       }
        if cr.inRead {
+               hijacked := cr.conn.hijacked()
                cr.unlock()
-               if cr.conn.hijacked() {
+               if hijacked {
                        panic("invalid Body.Read call. After hijacked, the original Request must not be used")
                }
                panic("invalid concurrent Body.Read call")
@@ -795,12 +809,12 @@ func (cr *connReader) Read(p []byte) (n int, err error) {
        }
        cr.inRead = true
        cr.unlock()
-       n, err = cr.conn.rwc.Read(p)
+       n, err = cr.rwc.Read(p)
 
        cr.lock()
        cr.inRead = false
        if err != nil {
-               cr.handleReadError(err)
+               cr.handleReadErrorLocked(err)
        }
        cr.remain -= int64(n)
        cr.unlock()
@@ -1986,7 +2000,7 @@ func (c *conn) serve(ctx context.Context) {
        c.cancelCtx = cancelCtx
        defer cancelCtx()
 
-       c.r = &connReader{conn: c}
+       c.r = &connReader{conn: c, rwc: c.rwc}
        c.bufr = newBufioReader(c.r)
        c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
 
@@ -2083,6 +2097,7 @@ func (c *conn) serve(ctx context.Context) {
                inFlightResponse = nil
                w.cancelCtx()
                if c.hijacked() {
+                       c.r.releaseConn()
                        return
                }
                w.finishRequest()