"mime"
"net"
"net/http"
+ "net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"net/url"
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
+//
+// 1xx responses are forwarded to the client if the underlying
+// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
// Rewrite must be a function which modifies
// the request into a new request to be sent
outreq.Header.Set("User-Agent", "")
}
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ h := rw.Header()
+ copyHeader(h, http.Header(header))
+ rw.WriteHeader(code)
+
+ // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
+ for k := range h {
+ delete(h, k)
+ }
+
+ return nil
+ },
+ }
+ outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
+
res, err := transport.RoundTrip(outreq)
if err != nil {
p.getErrorHandler()(rw, outreq, err)
"log"
"net/http"
"net/http/httptest"
+ "net/http/httptrace"
"net/http/internal/ascii"
+ "net/textproto"
"net/url"
"os"
"reflect"
t.Errorf("got response %q, want %q", got, want)
}
}
+
+func Test1xxResponses(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ h := w.Header()
+ h.Add("Link", "</style.css>; rel=preload; as=style")
+ h.Add("Link", "</script.js>; rel=preload; as=script")
+ w.WriteHeader(http.StatusEarlyHints)
+
+ h.Add("Link", "</foo.js>; rel=preload; as=script")
+ w.WriteHeader(http.StatusProcessing)
+
+ w.Write([]byte("Hello"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ checkLinkHeaders := func(t *testing.T, expected, got []string) {
+ t.Helper()
+
+ if len(expected) != len(got) {
+ t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
+ }
+
+ for i := range expected {
+ if i >= len(got) {
+ t.Errorf("Expected %q link header; got nothing", expected[i])
+
+ continue
+ }
+
+ if expected[i] != got[i] {
+ t.Errorf("Expected %q link header; got %q", expected[i], got[i])
+ }
+ }
+ }
+
+ var respCounter uint8
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ switch code {
+ case http.StatusEarlyHints:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
+ case http.StatusProcessing:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
+ default:
+ t.Error("Unexpected 1xx response")
+ }
+
+ respCounter++
+
+ return nil
+ },
+ }
+ req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
+
+ res, err := frontendClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+
+ defer res.Body.Close()
+
+ if respCounter != 2 {
+ t.Errorf("Expected 2 1xx responses; got %d", respCounter)
+ }
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
+
+ body, _ := io.ReadAll(res.Body)
+ if string(body) != "Hello" {
+ t.Errorf("Read body %q; want Hello", body)
+ }
+}