outreq.Header.Set("User-Agent", "")
}
+ var (
+ roundTripMutex sync.Mutex
+ roundTripDone bool
+ )
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ roundTripMutex.Lock()
+ defer roundTripMutex.Unlock()
+ if roundTripDone {
+ // If RoundTrip has returned, don't try to further modify
+ // the ResponseWriter's header map.
+ return nil
+ }
h := rw.Header()
copyHeader(h, http.Header(header))
rw.WriteHeader(code)
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
+ roundTripMutex.Lock()
+ roundTripDone = true
+ roundTripMutex.Unlock()
if err != nil {
p.getErrorHandler()(rw, outreq, err)
return
}
}
+func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) {
+ // https://go.dev/issue/65123: We use httptrace.Got1xxResponse to capture 1xx responses
+ // and proxy them. httptrace handlers can execute after RoundTrip returns, in particular
+ // after experiencing connection errors. When this happens, we shouldn't modify the
+ // ResponseWriter headers after ReverseProxy.ServeHTTP returns.
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ for i := 0; i < 5; i++ {
+ w.WriteHeader(103)
+ }
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+
+ rw := &testResponseWriter{}
+ func() {
+ // Cancel the request (and cause RoundTrip to return) immediately upon
+ // seeing a 1xx response.
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ cancel()
+ return nil
+ },
+ })
+
+ req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil)
+ proxyHandler.ServeHTTP(rw, req)
+ }()
+ // Trigger data race while iterating over response headers.
+ // When run with -race, this causes the condition in https://go.dev/issue/65123 often
+ // enough to detect reliably.
+ for _ = range rw.Header() {
+ }
+}
+
func Test1xxResponses(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
}
}
}
+
+type testResponseWriter struct {
+ h http.Header
+ writeHeader func(int)
+ write func([]byte) (int, error)
+}
+
+func (rw *testResponseWriter) Header() http.Header {
+ if rw.h == nil {
+ rw.h = make(http.Header)
+ }
+ return rw.h
+}
+
+func (rw *testResponseWriter) WriteHeader(statusCode int) {
+ if rw.writeHeader != nil {
+ rw.writeHeader(statusCode)
+ }
+}
+
+func (rw *testResponseWriter) Write(p []byte) (int, error) {
+ if rw.write != nil {
+ return rw.write(p)
+ }
+ return len(p), nil
+}