"net/url"
"os"
"os/exec"
+ "path/filepath"
"reflect"
+ "regexp"
"runtime"
"runtime/debug"
"sort"
}
}
+// Issue 30803: ensure that TimeoutHandler logs spurious
+// WriteHeader calls, for consistency with other Handlers.
+func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ pc, curFile, _, _ := runtime.Caller(0)
+ curFileBaseName := filepath.Base(curFile)
+ testFuncName := runtime.FuncForPC(pc).Name()
+
+ timeoutMsg := "timed out here!"
+ maxTimeout := 200 * time.Millisecond
+
+ tests := []struct {
+ name string
+ sleepTime time.Duration
+ wantResp string
+ }{
+ {
+ name: "return before timeout",
+ sleepTime: 0,
+ wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
+ },
+ {
+ name: "return after timeout",
+ sleepTime: maxTimeout * 2,
+ wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
+ len(timeoutMsg), timeoutMsg),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var lastSpuriousLine int32
+
+ sh := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(404)
+ w.WriteHeader(404)
+ w.WriteHeader(404)
+ w.WriteHeader(404)
+ _, _, line, _ := runtime.Caller(0)
+ atomic.StoreInt32(&lastSpuriousLine, int32(line))
+
+ <-time.After(tt.sleepTime)
+ })
+
+ logBuf := new(bytes.Buffer)
+ srvLog := log.New(logBuf, "", 0)
+ th := TimeoutHandler(sh, maxTimeout, timeoutMsg)
+ cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, th, optWithServerLog(srvLog))
+ defer cst.close()
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ // Deliberately removing the "Date" header since it is highly ephemeral
+ // and will cause failure if we try to match it exactly.
+ res.Header.Del("Date")
+ res.Header.Del("Content-Type")
+
+ // Match the response.
+ blob, _ := httputil.DumpResponse(res, true)
+ if g, w := string(blob), tt.wantResp; g != w {
+ t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
+ }
+
+ // Given 4 w.WriteHeader calls, only the first one is valid
+ // and the rest should be reported as the 3 spurious logs.
+ logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
+ if g, w := len(logEntries), 3; g != w {
+ blob, _ := json.MarshalIndent(logEntries, "", " ")
+ t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
+ }
+
+ // Now ensure that the regexes match exactly.
+ // "http: superfluous response.WriteHeader call from <fn>.func\d.\d (<curFile>:lastSpuriousLine-[1, 3]"
+ for i, logEntry := range logEntries {
+ wantLine := atomic.LoadInt32(&lastSpuriousLine) - 3 + int32(i)
+ pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
+ testFuncName, curFileBaseName, wantLine)
+ re := regexp.MustCompile(pat)
+ if !re.MatchString(logEntry) {
+ t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
+ }
+ }
+ })
+ }
+}
+
// fetchWireResponse is a helper for dialing to host,
// sending http1ReqBody as the payload and retrieving
// the response as it was sent on the wire.
r = r.WithContext(ctx)
done := make(chan struct{})
tw := &timeoutWriter{
- w: w,
- h: make(Header),
+ w: w,
+ h: make(Header),
+ req: r,
}
panicChan := make(chan interface{}, 1)
go func() {
w ResponseWriter
h Header
wbuf bytes.Buffer
+ req *Request
mu sync.Mutex
timedOut bool
return 0, ErrHandlerTimeout
}
if !tw.wroteHeader {
- tw.writeHeader(StatusOK)
+ tw.writeHeaderLocked(StatusOK)
}
return tw.wbuf.Write(p)
}
-func (tw *timeoutWriter) WriteHeader(code int) {
+func (tw *timeoutWriter) writeHeaderLocked(code int) {
checkWriteHeaderCode(code)
- tw.mu.Lock()
- defer tw.mu.Unlock()
- if tw.timedOut || tw.wroteHeader {
+
+ switch {
+ case tw.timedOut:
return
+ case tw.wroteHeader:
+ if tw.req != nil {
+ caller := relevantCaller()
+ logf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
+ }
+ default:
+ tw.wroteHeader = true
+ tw.code = code
}
- tw.writeHeader(code)
}
-func (tw *timeoutWriter) writeHeader(code int) {
- tw.wroteHeader = true
- tw.code = code
+func (tw *timeoutWriter) WriteHeader(code int) {
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ tw.writeHeaderLocked(code)
}
// onceCloseListener wraps a net.Listener, protecting it from