]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Server.ErrorLog; log and test TLS handshake errors
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 28 Feb 2014 20:12:51 +0000 (12:12 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 28 Feb 2014 20:12:51 +0000 (12:12 -0800)
Fixes #7291

LGTM=agl
R=golang-codereviews, agl
CC=agl, golang-codereviews
https://golang.org/cl/70250044

src/pkg/net/http/client_test.go
src/pkg/net/http/serve_test.go
src/pkg/net/http/server.go

index e5ad39c77418a568adff50ac258a236d31147605..b81af1a479feede344953fb0da09121785c706bd 100644 (file)
@@ -15,6 +15,7 @@ import (
        "fmt"
        "io"
        "io/ioutil"
+       "log"
        "net"
        . "net/http"
        "net/http/httptest"
@@ -23,6 +24,7 @@ import (
        "strings"
        "sync"
        "testing"
+       "time"
 )
 
 var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -54,6 +56,13 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) {
        }
 }
 
+type chanWriter chan string
+
+func (w chanWriter) Write(p []byte) (n int, err error) {
+       w <- string(p)
+       return len(p), nil
+}
+
 func TestClient(t *testing.T) {
        defer afterTest(t)
        ts := httptest.NewServer(robotsTxtHandler)
@@ -564,6 +573,8 @@ func TestClientInsecureTransport(t *testing.T) {
        ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
                w.Write([]byte("Hello"))
        }))
+       errc := make(chanWriter, 10) // but only expecting 1
+       ts.Config.ErrorLog = log.New(errc, "", 0)
        defer ts.Close()
 
        // TODO(bradfitz): add tests for skipping hostname checks too?
@@ -585,6 +596,16 @@ func TestClientInsecureTransport(t *testing.T) {
                        res.Body.Close()
                }
        }
+
+       select {
+       case v := <-errc:
+               if !strings.Contains(v, "bad certificate") {
+                       t.Errorf("expected an error log message containing 'bad certificate'; got %q", v)
+               }
+       case <-time.After(5 * time.Second):
+               t.Errorf("timeout waiting for logged error")
+       }
+
 }
 
 func TestClientErrorWithRequestURI(t *testing.T) {
@@ -635,6 +656,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
        defer afterTest(t)
        ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
        defer ts.Close()
+       errc := make(chanWriter, 10) // but only expecting 1
+       ts.Config.ErrorLog = log.New(errc, "", 0)
 
        trans := newTLSTransport(t, ts)
        trans.TLSClientConfig.ServerName = "badserver"
@@ -646,6 +669,14 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
        if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
                t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
        }
+       select {
+       case v := <-errc:
+               if !strings.Contains(v, "bad certificate") {
+                       t.Errorf("expected an error log message containing 'bad certificate'; got %q", v)
+               }
+       case <-time.After(5 * time.Second):
+               t.Errorf("timeout waiting for logged error")
+       }
 }
 
 // Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
index c8ca03e07d1406e6a5a82517528451063550a910..4fd6ff234dba94ecce448676173b5a6b9f8ebd6e 100644 (file)
@@ -851,7 +851,9 @@ func TestTLSHandshakeTimeout(t *testing.T) {
        }
        defer afterTest(t)
        ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+       errc := make(chanWriter, 10) // but only expecting 1
        ts.Config.ReadTimeout = 250 * time.Millisecond
+       ts.Config.ErrorLog = log.New(errc, "", 0)
        ts.StartTLS()
        defer ts.Close()
        conn, err := net.Dial("tcp", ts.Listener.Addr().String())
@@ -866,6 +868,14 @@ func TestTLSHandshakeTimeout(t *testing.T) {
                        t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
                }
        })
+       select {
+       case v := <-errc:
+               if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
+                       t.Errorf("expected a TLS handshake timeout error; got %q", v)
+               }
+       case <-time.After(5 * time.Second):
+               t.Errorf("timeout waiting for logged error")
+       }
 }
 
 func TestTLSServer(t *testing.T) {
@@ -878,6 +888,7 @@ func TestTLSServer(t *testing.T) {
                        }
                }
        }))
+       ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
        defer ts.Close()
 
        // Connect an idle TCP connection to this server before we run
index ba3a530adcdf61c3e7f5243c67d19947574e372d..b77ec6cb6fba6a6fe1adb4089d2a812c929507c7 100644 (file)
@@ -615,11 +615,11 @@ const maxPostHandlerReadBytes = 256 << 10
 
 func (w *response) WriteHeader(code int) {
        if w.conn.hijacked() {
-               log.Print("http: response.WriteHeader on hijacked connection")
+               w.conn.server.logf("http: response.WriteHeader on hijacked connection")
                return
        }
        if w.wroteHeader {
-               log.Print("http: multiple response.WriteHeader calls")
+               w.conn.server.logf("http: multiple response.WriteHeader calls")
                return
        }
        w.wroteHeader = true
@@ -634,7 +634,7 @@ func (w *response) WriteHeader(code int) {
                if err == nil && v >= 0 {
                        w.contentLength = v
                } else {
-                       log.Printf("http: invalid Content-Length of %q", cl)
+                       w.conn.server.logf("http: invalid Content-Length of %q", cl)
                        w.handlerHeader.Del("Content-Length")
                }
        }
@@ -817,7 +817,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
        if hasCL && hasTE && te != "identity" {
                // TODO: return an error if WriteHeader gets a return parameter
                // For now just ignore the Content-Length.
-               log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
+               w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
                        te, w.contentLength)
                delHeader("Content-Length")
                hasCL = false
@@ -963,7 +963,7 @@ func (w *response) WriteString(data string) (n int, err error) {
 // either dataB or dataS is non-zero.
 func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
        if w.conn.hijacked() {
-               log.Print("http: response.Write on hijacked connection")
+               w.conn.server.logf("http: response.Write on hijacked connection")
                return 0, ErrHijacked
        }
        if !w.wroteHeader {
@@ -1096,7 +1096,7 @@ func (c *conn) serve() {
                        const size = 64 << 10
                        buf := make([]byte, size)
                        buf = buf[:runtime.Stack(buf, false)]
-                       log.Printf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
+                       c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
                }
                if !c.hijacked() {
                        c.close()
@@ -1112,6 +1112,7 @@ func (c *conn) serve() {
                        c.rwc.SetWriteDeadline(time.Now().Add(d))
                }
                if err := tlsConn.Handshake(); err != nil {
+                       c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
                        return
                }
                c.tlsState = new(tls.ConnectionState)
@@ -1604,6 +1605,12 @@ type Server struct {
        // ConnState type and associated constants for details.
        ConnState func(net.Conn, ConnState)
 
+       // ErrorLog specifies an optional logger for errors accepting
+       // connections and unexpected behavior from handlers.
+       // If nil, logging goes to os.Stderr via the log package's
+       // standard logger.
+       ErrorLog *log.Logger
+
        disableKeepAlives int32 // accessed atomically.
 }
 
@@ -1704,7 +1711,7 @@ func (srv *Server) Serve(l net.Listener) error {
                                if max := 1 * time.Second; tempDelay > max {
                                        tempDelay = max
                                }
-                               log.Printf("http: Accept error: %v; retrying in %v", e, tempDelay)
+                               srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay)
                                time.Sleep(tempDelay)
                                continue
                        }
@@ -1735,6 +1742,14 @@ func (s *Server) SetKeepAlivesEnabled(v bool) {
        }
 }
 
+func (s *Server) logf(format string, args ...interface{}) {
+       if s.ErrorLog != nil {
+               s.ErrorLog.Printf(format, args...)
+       } else {
+               log.Printf(format, args...)
+       }
+}
+
 // ListenAndServe listens on the TCP network address addr
 // and then calls Serve with handler to handle requests
 // on incoming connections.  Handler is typically nil,