]> Cypherpunks repositories - gostls13.git/commitdiff
[release-branch.go1.17] net/http/httptest: fix race in Server.Close
authorMaisem Ali <maisem@tailscale.com>
Mon, 21 Mar 2022 17:43:45 +0000 (17:43 +0000)
committerHeschi Kreinick <heschi@google.com>
Mon, 9 May 2022 20:17:13 +0000 (20:17 +0000)
When run with race detector the test fails without the fix.

For #51799
Fixes #52455

Change-Id: I273adb6d3a2b1e0d606b9c27ab4c6a9aa4aa8064
GitHub-Last-Rev: a5ddd146a2a65f2e817eed5133449c79b3af2562
GitHub-Pull-Request: golang/go#51805
Reviewed-on: https://go-review.googlesource.com/c/go/+/393974
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
(cherry picked from commit 1d19cea740a5a044848aaab3dc119f60c947be1d)
Reviewed-on: https://go-review.googlesource.com/c/go/+/401318
Reviewed-by: David Chase <drchase@google.com>
src/net/http/httptest/server.go
src/net/http/httptest/server_test.go

index 4f85ff55d89397e59f96cb7631ffa01cffe7ac13..1c0c0f698756953106c7742c58d0384b2ab9d30b 100644 (file)
@@ -317,21 +317,17 @@ func (s *Server) wrap() {
                s.mu.Lock()
                defer s.mu.Unlock()
 
-               // Keep Close from returning until the user's ConnState hook
-               // (if any) finishes. Without this, the call to forgetConn
-               // below might send the count to 0 before we run the hook.
-               s.wg.Add(1)
-               defer s.wg.Done()
-
                switch cs {
                case http.StateNew:
-                       s.wg.Add(1)
                        if _, exists := s.conns[c]; exists {
                                panic("invalid state transition")
                        }
                        if s.conns == nil {
                                s.conns = make(map[net.Conn]http.ConnState)
                        }
+                       // Add c to the set of tracked conns and increment it to the
+                       // waitgroup.
+                       s.wg.Add(1)
                        s.conns[c] = cs
                        if s.closed {
                                // Probably just a socket-late-binding dial from
@@ -358,7 +354,14 @@ func (s *Server) wrap() {
                                s.closeConn(c)
                        }
                case http.StateHijacked, http.StateClosed:
-                       s.forgetConn(c)
+                       // Remove c from the set of tracked conns and decrement it from the
+                       // waitgroup, unless it was previously removed.
+                       if _, ok := s.conns[c]; ok {
+                               delete(s.conns, c)
+                               // Keep Close from returning until the user's ConnState hook
+                               // (if any) finishes.
+                               defer s.wg.Done()
+                       }
                }
                if oldHook != nil {
                        oldHook(c, cs)
@@ -378,13 +381,3 @@ func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
                done <- struct{}{}
        }
 }
-
-// forgetConn removes c from the set of tracked conns and decrements it from the
-// waitgroup, unless it was previously removed.
-// s.mu must be held.
-func (s *Server) forgetConn(c net.Conn) {
-       if _, ok := s.conns[c]; ok {
-               delete(s.conns, c)
-               s.wg.Done()
-       }
-}
index 39568b358c2610e44eb48b517236582ed42a99a2..5313f65456c350e9f74b94b2070a0a36b46506f7 100644 (file)
@@ -9,6 +9,7 @@ import (
        "io"
        "net"
        "net/http"
+       "sync"
        "testing"
 )
 
@@ -203,6 +204,59 @@ func TestServerZeroValueClose(t *testing.T) {
        ts.Close() // tests that it doesn't panic
 }
 
+// Issue 51799: test hijacking a connection and then closing it
+// concurrently with closing the server.
+func TestCloseHijackedConnection(t *testing.T) {
+       hijacked := make(chan net.Conn)
+       ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               defer close(hijacked)
+               hj, ok := w.(http.Hijacker)
+               if !ok {
+                       t.Fatal("failed to hijack")
+               }
+               c, _, err := hj.Hijack()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               hijacked <- c
+       }))
+
+       var wg sync.WaitGroup
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               req, err := http.NewRequest("GET", ts.URL, nil)
+               if err != nil {
+                       t.Log(err)
+               }
+               // Use a client not associated with the Server.
+               var c http.Client
+               resp, err := c.Do(req)
+               if err != nil {
+                       t.Log(err)
+                       return
+               }
+               resp.Body.Close()
+       }()
+
+       wg.Add(1)
+       conn := <-hijacked
+       go func(conn net.Conn) {
+               defer wg.Done()
+               // Close the connection and then inform the Server that
+               // we closed it.
+               conn.Close()
+               ts.Config.ConnState(conn, http.StateClosed)
+       }(conn)
+
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               ts.Close()
+       }()
+       wg.Wait()
+}
+
 func TestTLSServerWithHTTP2(t *testing.T) {
        modes := []struct {
                name      string