From b2a3b54b9520ce869d79ac8bce836a540ba45d09 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 13 Jan 2017 21:43:56 +0000 Subject: [PATCH] net/http: make sure Hijack's bufio.Reader includes pre-read background byte Previously, if the Hijack called stopped the background read call which read a byte, that byte was sitting in memory, buffered, ready to be Read by Hijack's returned bufio.Reader, but it wasn't yet in the bufio.Reader's buffer itself, so bufio.Reader.Buffered() reported 1 byte fewer. This matters for callers who wanted to stitch together any buffered data (with bufio.Reader.Peek(bufio.Reader.Buffered())) with Hijack's returned net.Conn. Otherwise there was no way for callers to know a byte was read. Change-Id: Id7cb0a0a33fe2f33d79250e13dbaa9c0f7abba13 Reviewed-on: https://go-review.googlesource.com/35232 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Joe Tsai --- src/net/http/serve_test.go | 67 ++++++++++++++++++++++++++++++++++++++ src/net/http/server.go | 10 +++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 072da2552b..22188ab483 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -5173,3 +5173,70 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { } wg.Wait() } + +// Test that the bufio.Reader returned by Hijack includes any buffered +// byte (from the Server's backgroundRead) in its buffer. We want the +// Handler code to be able to tell that a byte is available via +// bufio.Reader.Buffered(), without resorting to Reading it +// (potentially blocking) to get at it. +func TestServerHijackGetsBackgroundByte(t *testing.T) { + setParallel(t) + defer afterTest(t) + done := make(chan struct{}) + inHandler := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(done) + + // Tell the client to send more data after the GET request. + inHandler <- true + + // Wait until the HTTP server sees the extra data + // after the GET request. The HTTP server fires the + // close notifier here, assuming it's a pipelined + // request, as documented. + select { + case <-w.(CloseNotifier).CloseNotify(): + case <-time.After(5 * time.Second): + t.Error("timeout") + return + } + + conn, buf, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + n := buf.Reader.Buffered() + if n != 1 { + t.Errorf("buffered data = %d; want 1", n) + } + peek, err := buf.Reader.Peek(3) + if string(peek) != "foo" || err != nil { + t.Errorf("Peek = %q, %v; want foo, nil", peek, err) + } + })) + defer ts.Close() + + cn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cn.Close() + if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil { + t.Fatal(err) + } + <-inHandler + if _, err := cn.Write([]byte("foo")); err != nil { + t.Fatal(err) + } + + if err := cn.(*net.TCPConn).CloseWrite(); err != nil { + t.Fatal(err) + } + select { + case <-done: + case <-time.After(2 * time.Second): + t.Error("timeout") + } +} diff --git a/src/net/http/server.go b/src/net/http/server.go index 96236489bd..df70a15193 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -164,7 +164,7 @@ type Flusher interface { // should always test for this ability at runtime. type Hijacker interface { // Hijack lets the caller take over the connection. - // After a call to Hijack(), the HTTP server library + // After a call to Hijack the HTTP server library // will not do anything else with the connection. // // It becomes the caller's responsibility to manage @@ -174,6 +174,9 @@ type Hijacker interface { // already set, depending on the configuration of the // Server. It is the caller's responsibility to set // or clear those deadlines as needed. + // + // The returned bufio.Reader may contain unprocessed buffered + // data from the client. Hijack() (net.Conn, *bufio.ReadWriter, error) } @@ -293,6 +296,11 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { rwc.SetDeadline(time.Time{}) buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) + if c.r.hasByte { + if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil { + return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err) + } + } c.setState(rwc, StateHijacked) return } -- 2.48.1