}
}
+type tlogWriter struct{ t *testing.T }
+
+func (w tlogWriter) Write(p []byte) (int, error) {
+ w.t.Log(string(p))
+ return len(p), nil
+}
+
+func TestWriteHeaderSwitchingProtocols(t *testing.T) {
+ run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
+}
+func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
+ const wantBody = "want"
+ const wantUpgrade = "someProto"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "Upgrade")
+ w.Header().Set("Upgrade", wantUpgrade)
+ w.WriteHeader(StatusSwitchingProtocols)
+ NewResponseController(w).Flush()
+
+ // Writing headers or the body after sending a 101 header should fail.
+ w.WriteHeader(200)
+ if _, err := w.Write([]byte("x")); err == nil {
+ t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
+ }
+
+ c, _, err := NewResponseController(w).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ defer c.Close()
+ if _, err := c.Write([]byte(wantBody)); err != nil {
+ t.Errorf("Write to hijacked body: %v", err)
+ }
+ }), func(ts *httptest.Server) {
+ // Don't spam log with warning about superfluous WriteHeader call.
+ ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
+ }).ts
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("net.Dial: %v", err)
+ }
+ _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ t.Fatalf("conn.Write: %v", err)
+ }
+ defer conn.Close()
+
+ r := bufio.NewReader(conn)
+ res, err := ReadResponse(r, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal("ReadResponse error:", err)
+ }
+ if res.StatusCode != StatusSwitchingProtocols {
+ t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
+ }
+ if got := res.Header.Get("Upgrade"); got != wantUpgrade {
+ t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
+ }
+ body, err := io.ReadAll(r)
+ if err != nil {
+ t.Error(err)
+ }
+ if string(body) != wantBody {
+ t.Errorf("Response body = %q, want %q", string(body), wantBody)
+ }
+}
+
func TestMuxRedirectRelative(t *testing.T) {
setParallel(t)
req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))