]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: don't reuse a server connection after any Write errors
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 15 Oct 2014 15:51:12 +0000 (17:51 +0200)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 15 Oct 2014 15:51:12 +0000 (17:51 +0200)
Fixes #8534

LGTM=adg
R=adg
CC=golang-codereviews
https://golang.org/cl/149340044

src/net/http/serve_test.go
src/net/http/server.go

index 702bffdc13624512809d9b4e86c3fc8c4dde65e7..bb44ac8537ab0b1fd51362376c041d98cd8be018 100644 (file)
@@ -2659,6 +2659,103 @@ func TestCloseWrite(t *testing.T) {
        }
 }
 
+// This verifies that a handler can Flush and then Hijack.
+//
+// An similar test crashed once during development, but it was only
+// testing this tangentially and temporarily until another TODO was
+// fixed.
+//
+// So add an explicit test for this.
+func TestServerFlushAndHijack(t *testing.T) {
+       defer afterTest(t)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               io.WriteString(w, "Hello, ")
+               w.(Flusher).Flush()
+               conn, buf, _ := w.(Hijacker).Hijack()
+               buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
+               if err := buf.Flush(); err != nil {
+                       t.Error(err)
+               }
+               if err := conn.Close(); err != nil {
+                       t.Error(err)
+               }
+       }))
+       defer ts.Close()
+       res, err := Get(ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer res.Body.Close()
+       all, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if want := "Hello, world!"; string(all) != want {
+               t.Errorf("Got %q; want %q", all, want)
+       }
+}
+
+// golang.org/issue/8534 -- the Server shouldn't reuse a connection
+// for keep-alive after it's seen any Write error (e.g. a timeout) on
+// that net.Conn.
+//
+// To test, verify we don't timeout or see fewer unique client
+// addresses (== unique connections) than requests.
+func TestServerKeepAliveAfterWriteError(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping in -short mode")
+       }
+       defer afterTest(t)
+       const numReq = 3
+       addrc := make(chan string, numReq)
+       ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               addrc <- r.RemoteAddr
+               time.Sleep(500 * time.Millisecond)
+               w.(Flusher).Flush()
+       }))
+       ts.Config.WriteTimeout = 250 * time.Millisecond
+       ts.Start()
+       defer ts.Close()
+
+       errc := make(chan error, numReq)
+       go func() {
+               defer close(errc)
+               for i := 0; i < numReq; i++ {
+                       res, err := Get(ts.URL)
+                       if res != nil {
+                               res.Body.Close()
+                       }
+                       errc <- err
+               }
+       }()
+
+       timeout := time.NewTimer(numReq * 2 * time.Second) // 4x overkill
+       defer timeout.Stop()
+       addrSeen := map[string]bool{}
+       numOkay := 0
+       for {
+               select {
+               case v := <-addrc:
+                       addrSeen[v] = true
+               case err, ok := <-errc:
+                       if !ok {
+                               if len(addrSeen) != numReq {
+                                       t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
+                               }
+                               if numOkay != 0 {
+                                       t.Errorf("got %d successful client requests; want 0", numOkay)
+                               }
+                               return
+                       }
+                       if err == nil {
+                               numOkay++
+                       }
+               case <-timeout.C:
+                       t.Fatal("timeout waiting for requests to complete")
+               }
+       }
+}
+
 func BenchmarkClientServer(b *testing.B) {
        b.ReportAllocs()
        b.StopTimer()
index b5959f7321f464ad74eb97e35e219c68a5974e6c..008d5aa7a748fbc9f51ea4d999d6d3e7c363df1a 100644 (file)
@@ -114,6 +114,8 @@ type conn struct {
        remoteAddr string               // network address of remote side
        server     *Server              // the Server on which the connection arrived
        rwc        net.Conn             // i/o connection
+       w          io.Writer            // checkConnErrorWriter's copy of wrc, not zeroed on Hijack
+       werr       error                // any errors writing to w
        sr         liveSwitchReader     // where the LimitReader reads from; usually the rwc
        lr         *io.LimitedReader    // io.LimitReader(sr)
        buf        *bufio.ReadWriter    // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
@@ -432,13 +434,14 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
        c.remoteAddr = rwc.RemoteAddr().String()
        c.server = srv
        c.rwc = rwc
+       c.w = rwc
        if debugServerConnections {
                c.rwc = newLoggingConn("server", c.rwc)
        }
        c.sr = liveSwitchReader{r: c.rwc}
        c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
        br := newBufioReader(c.lr)
-       bw := newBufioWriterSize(c.rwc, 4<<10)
+       bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
        c.buf = bufio.NewReadWriter(br, bw)
        return c, nil
 }
@@ -956,8 +959,10 @@ func (w *response) bodyAllowed() bool {
 // 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes
 // 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
 //    and which writes the chunk headers, if needed.
-// 4. conn.buf, a bufio.Writer of default (4kB) bytes
-// 5. the rwc, the net.Conn.
+// 4. conn.buf, a bufio.Writer of default (4kB) bytes, writing to ->
+// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write
+//    and populates c.werr with it if so. but otherwise writes to:
+// 6. the rwc, the net.Conn.
 //
 // TODO(bradfitz): short-circuit some of the buffering when the
 // initial header contains both a Content-Type and Content-Length.
@@ -1027,6 +1032,12 @@ func (w *response) finishRequest() {
                // Did not write enough. Avoid getting out of sync.
                w.closeAfterReply = true
        }
+
+       // There was some error writing to the underlying connection
+       // during the request, so don't re-use this conn.
+       if w.conn.werr != nil {
+               w.closeAfterReply = true
+       }
 }
 
 func (w *response) Flush() {
@@ -2068,3 +2079,18 @@ func (c *loggingConn) Close() (err error) {
        log.Printf("%s.Close() = %v", c.name, err)
        return
 }
+
+// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr.
+// It only contains one field (and a pointer field at that), so it
+// fits in an interface value without an extra allocation.
+type checkConnErrorWriter struct {
+       c *conn
+}
+
+func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
+       n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil.
+       if err != nil && w.c.werr == nil {
+               w.c.werr = err
+       }
+       return
+}