]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: also use Server.ReadHeaderTimeout for TLS handshake deadline
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 14 Oct 2021 15:45:16 +0000 (08:45 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 14 Oct 2021 22:15:24 +0000 (22:15 +0000)
Fixes #48120

Change-Id: I72e89af8aaf3310e348d8ab639925ce0bf84204d
Reviewed-on: https://go-review.googlesource.com/c/go/+/355870
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
src/net/http/server.go
src/net/http/server_test.go

index 55fd4ae22f90013e593206791a7a80ba1da36327..e9b0b4d9bd9a28c934c6b8ab3373090b09880163 100644 (file)
@@ -865,6 +865,28 @@ func (srv *Server) initialReadLimitSize() int64 {
        return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
 }
 
+// tlsHandshakeTimeout returns the time limit permitted for the TLS
+// handshake, or zero for unlimited.
+//
+// It returns the minimum of any positive ReadHeaderTimeout,
+// ReadTimeout, or WriteTimeout.
+func (srv *Server) tlsHandshakeTimeout() time.Duration {
+       var ret time.Duration
+       for _, v := range [...]time.Duration{
+               srv.ReadHeaderTimeout,
+               srv.ReadTimeout,
+               srv.WriteTimeout,
+       } {
+               if v <= 0 {
+                       continue
+               }
+               if ret == 0 || v < ret {
+                       ret = v
+               }
+       }
+       return ret
+}
+
 // wrapper around io.ReadCloser which on first read, sends an
 // HTTP/1.1 100 Continue header
 type expectContinueReader struct {
@@ -1816,11 +1838,11 @@ func (c *conn) serve(ctx context.Context) {
        }()
 
        if tlsConn, ok := c.rwc.(*tls.Conn); ok {
-               if d := c.server.ReadTimeout; d > 0 {
-                       c.rwc.SetReadDeadline(time.Now().Add(d))
-               }
-               if d := c.server.WriteTimeout; d > 0 {
-                       c.rwc.SetWriteDeadline(time.Now().Add(d))
+               tlsTO := c.server.tlsHandshakeTimeout()
+               if tlsTO > 0 {
+                       dl := time.Now().Add(tlsTO)
+                       c.rwc.SetReadDeadline(dl)
+                       c.rwc.SetWriteDeadline(dl)
                }
                if err := tlsConn.HandshakeContext(ctx); err != nil {
                        // If the handshake failed due to the client not speaking
@@ -1834,6 +1856,11 @@ func (c *conn) serve(ctx context.Context) {
                        c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
                        return
                }
+               // Restore Conn-level deadlines.
+               if tlsTO > 0 {
+                       c.rwc.SetReadDeadline(time.Time{})
+                       c.rwc.SetWriteDeadline(time.Time{})
+               }
                c.tlsState = new(tls.ConnectionState)
                *c.tlsState = tlsConn.ConnectionState()
                if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) {
index 0132f3ba5fbd77611fb22d9c965f98f98afccb33..d17c5c1e7ef5e6329587d37761f38f3d8329825f 100644 (file)
@@ -9,8 +9,61 @@ package http
 import (
        "fmt"
        "testing"
+       "time"
 )
 
+func TestServerTLSHandshakeTimeout(t *testing.T) {
+       tests := []struct {
+               s    *Server
+               want time.Duration
+       }{
+               {
+                       s:    &Server{},
+                       want: 0,
+               },
+               {
+                       s: &Server{
+                               ReadTimeout: -1,
+                       },
+                       want: 0,
+               },
+               {
+                       s: &Server{
+                               ReadTimeout: 5 * time.Second,
+                       },
+                       want: 5 * time.Second,
+               },
+               {
+                       s: &Server{
+                               ReadTimeout:  5 * time.Second,
+                               WriteTimeout: -1,
+                       },
+                       want: 5 * time.Second,
+               },
+               {
+                       s: &Server{
+                               ReadTimeout:  5 * time.Second,
+                               WriteTimeout: 4 * time.Second,
+                       },
+                       want: 4 * time.Second,
+               },
+               {
+                       s: &Server{
+                               ReadTimeout:       5 * time.Second,
+                               ReadHeaderTimeout: 2 * time.Second,
+                               WriteTimeout:      4 * time.Second,
+                       },
+                       want: 2 * time.Second,
+               },
+       }
+       for i, tt := range tests {
+               got := tt.s.tlsHandshakeTimeout()
+               if got != tt.want {
+                       t.Errorf("%d. got %v; want %v", i, got, tt.want)
+               }
+       }
+}
+
 func BenchmarkServerMatch(b *testing.B) {
        fn := func(w ResponseWriter, r *Request) {
                fmt.Fprintf(w, "OK")