]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: handle WriteHeader(101) as a non-informational header
authorDamien Neil <dneil@google.com>
Tue, 18 Apr 2023 15:50:02 +0000 (08:50 -0700)
committerGopher Robot <gobot@golang.org>
Mon, 15 May 2023 17:35:28 +0000 (17:35 +0000)
Prior to Go 1.19 adding support for sending 1xx informational headers
with ResponseWriter.WriteHeader, WriteHeader(101) would send a 101
status and disable further writes to the response. This behavior
was not documented, but is intentional: Writing to the response
body explicitly checks to see if a 101 status has been sent before
writing.

Restore the pre-1.19 behavior when writing a 101 Switching Protocols
header: The header is sent, no subsequent headers are sent, and
subsequent writes to the response body fail.

For #59564

Change-Id: I72c116f88405b1ef5067b510f8c7cff0b36951ee
Reviewed-on: https://go-review.googlesource.com/c/go/+/485775
Reviewed-by: Bryan Mills <bcmills@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/net/http/serve_test.go
src/net/http/server.go

index a21518b56361f44cb17c9658ef08cbbe51905a6f..12f6b768bd6a5effa252435545c2e70ae284be91 100644 (file)
@@ -6431,6 +6431,75 @@ func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
        }
 }
 
+type tlogWriter struct{ t *testing.T }
+
+func (w tlogWriter) Write(p []byte) (int, error) {
+       w.t.Log(string(p))
+       return len(p), nil
+}
+
+func TestWriteHeaderSwitchingProtocols(t *testing.T) {
+       run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
+}
+func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
+       const wantBody = "want"
+       const wantUpgrade = "someProto"
+       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Set("Connection", "Upgrade")
+               w.Header().Set("Upgrade", wantUpgrade)
+               w.WriteHeader(StatusSwitchingProtocols)
+               NewResponseController(w).Flush()
+
+               // Writing headers or the body after sending a 101 header should fail.
+               w.WriteHeader(200)
+               if _, err := w.Write([]byte("x")); err == nil {
+                       t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
+               }
+
+               c, _, err := NewResponseController(w).Hijack()
+               if err != nil {
+                       t.Errorf("Hijack: %v", err)
+                       return
+               }
+               defer c.Close()
+               if _, err := c.Write([]byte(wantBody)); err != nil {
+                       t.Errorf("Write to hijacked body: %v", err)
+               }
+       }), func(ts *httptest.Server) {
+               // Don't spam log with warning about superfluous WriteHeader call.
+               ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
+       }).ts
+
+       conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+       if err != nil {
+               t.Fatalf("net.Dial: %v", err)
+       }
+       _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+       if err != nil {
+               t.Fatalf("conn.Write: %v", err)
+       }
+       defer conn.Close()
+
+       r := bufio.NewReader(conn)
+       res, err := ReadResponse(r, &Request{Method: "GET"})
+       if err != nil {
+               t.Fatal("ReadResponse error:", err)
+       }
+       if res.StatusCode != StatusSwitchingProtocols {
+               t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
+       }
+       if got := res.Header.Get("Upgrade"); got != wantUpgrade {
+               t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
+       }
+       body, err := io.ReadAll(r)
+       if err != nil {
+               t.Error(err)
+       }
+       if string(body) != wantBody {
+               t.Errorf("Response body = %q, want %q", string(body), wantBody)
+       }
+}
+
 func TestMuxRedirectRelative(t *testing.T) {
        setParallel(t)
        req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
index e82669a18053ee88493e570040fffed26c47bb08..680c5f68f425fa2a4e01af43d5180e3e3aa55389 100644 (file)
@@ -1154,8 +1154,11 @@ func (w *response) WriteHeader(code int) {
        }
        checkWriteHeaderCode(code)
 
-       // Handle informational headers
-       if code >= 100 && code <= 199 {
+       // Handle informational headers.
+       //
+       // We shouldn't send any further headers after 101 Switching Protocols,
+       // so it takes the non-informational path.
+       if code >= 100 && code <= 199 && code != StatusSwitchingProtocols {
                // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read()
                if code == 100 && w.canWriteContinue.Load() {
                        w.writeContinueMu.Lock()