]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: add Client and Certificate methods to Server
authorJohan Brandhorst <johan.brandhorst@gmail.com>
Wed, 21 Dec 2016 13:49:04 +0000 (13:49 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 3 Mar 2017 21:02:17 +0000 (21:02 +0000)
Adds a function for easily accessing the x509.Certificate
of a Server, if there is one. Also adds a helper function
for getting a http.Client suitable for use with the server.

This makes the steps required to test a httptest
TLS server simpler.

Fixes #18411

Change-Id: I2e78fe1e54e31bed9c641be2d9a099f698c7bbde
Reviewed-on: https://go-review.googlesource.com/34639
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/go/build/deps_test.go
src/net/http/httptest/example_test.go
src/net/http/httptest/server.go
src/net/http/httptest/server_test.go

index f8ba53288e660da5bc6a759e5f0cf72ad5ac0039..2adc06f39bc84dc670eee24b1c844acc2e638906 100644 (file)
@@ -411,7 +411,7 @@ var pkgDeps = map[string][]string{
        "net/http/cgi":       {"L4", "NET", "OS", "crypto/tls", "net/http", "regexp"},
        "net/http/cookiejar": {"L4", "NET", "net/http"},
        "net/http/fcgi":      {"L4", "NET", "OS", "net/http", "net/http/cgi"},
-       "net/http/httptest":  {"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal"},
+       "net/http/httptest":  {"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509"},
        "net/http/httputil":  {"L4", "NET", "OS", "context", "net/http", "net/http/internal"},
        "net/http/pprof":     {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"},
        "net/rpc":            {"L4", "NET", "encoding/gob", "html/template", "net/http"},
index bd2c49642b6cc575dda21c212ab9319bdcbd107b..e3d392130e0fa66554d440a6748d1addfb08f643 100644 (file)
@@ -54,3 +54,25 @@ func ExampleServer() {
        fmt.Printf("%s", greeting)
        // Output: Hello, client
 }
+
+func ExampleNewTLSServer() {
+       ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               fmt.Fprintln(w, "Hello, client")
+       }))
+       defer ts.Close()
+
+       client := ts.Client()
+       res, err := client.Get(ts.URL)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       greeting, err := ioutil.ReadAll(res.Body)
+       res.Body.Close()
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       fmt.Printf("%s", greeting)
+       // Output: Hello, client
+}
index 711821433bf22dd6ff47fbf581af80f074034517..56ad18ee9ba8b8cbac8381b4c929a02af287c4ed 100644 (file)
@@ -9,6 +9,7 @@ package httptest
 import (
        "bytes"
        "crypto/tls"
+       "crypto/x509"
        "flag"
        "fmt"
        "log"
@@ -35,6 +36,9 @@ type Server struct {
        // before Start or StartTLS.
        Config *http.Server
 
+       // certificate is a parsed version of the TLS config certificate, if present.
+       certificate *x509.Certificate
+
        // wg counts the number of outstanding HTTP requests on this server.
        // Close blocks until all requests are finished.
        wg sync.WaitGroup
@@ -42,6 +46,10 @@ type Server struct {
        mu     sync.Mutex // guards closed and conns
        closed bool
        conns  map[net.Conn]http.ConnState // except terminal states
+
+       // client is configured for use with the server.
+       // Its transport is automatically closed when Close is called.
+       client *http.Client
 }
 
 func newLocalListener() net.Listener {
@@ -85,6 +93,7 @@ func NewUnstartedServer(handler http.Handler) *Server {
        return &Server{
                Listener: newLocalListener(),
                Config:   &http.Server{Handler: handler},
+               client:   &http.Client{},
        }
 }
 
@@ -124,6 +133,17 @@ func (s *Server) StartTLS() {
        if len(s.TLS.Certificates) == 0 {
                s.TLS.Certificates = []tls.Certificate{cert}
        }
+       s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
+       if err != nil {
+               panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+       }
+       certpool := x509.NewCertPool()
+       certpool.AddCert(s.certificate)
+       s.client.Transport = &http.Transport{
+               TLSClientConfig: &tls.Config{
+                       RootCAs: certpool,
+               },
+       }
        s.Listener = tls.NewListener(s.Listener, s.TLS)
        s.URL = "https://" + s.Listener.Addr().String()
        s.wrap()
@@ -186,6 +206,11 @@ func (s *Server) Close() {
                t.CloseIdleConnections()
        }
 
+       // Also close the client idle connections.
+       if t, ok := s.client.Transport.(closeIdleTransport); ok {
+               t.CloseIdleConnections()
+       }
+
        s.wg.Wait()
 }
 
@@ -228,6 +253,19 @@ func (s *Server) CloseClientConnections() {
        }
 }
 
+// Certificate returns the certificate used by the server, or nil if
+// the server doesn't use TLS.
+func (s *Server) Certificate() *x509.Certificate {
+       return s.certificate
+}
+
+// Client returns an HTTP client configured for making requests to the server.
+// It is configured to trust the server's TLS test certificate and will
+// close its idle connections on Server.Close.
+func (s *Server) Client() *http.Client {
+       return s.client
+}
+
 func (s *Server) goServe() {
        s.wg.Add(1)
        go func() {
index d032c5983b710351a40cf522bd7a9e0e00be93d2..7d80fa15dd8cd21e50ac6c7fd59112c71fd62b07 100644 (file)
@@ -22,6 +22,7 @@ func TestServer(t *testing.T) {
                t.Fatal(err)
        }
        got, err := ioutil.ReadAll(res.Body)
+       res.Body.Close()
        if err != nil {
                t.Fatal(err)
        }
@@ -98,3 +99,25 @@ func TestServerCloseClientConnections(t *testing.T) {
                t.Fatalf("Unexpected response: %#v", res)
        }
 }
+
+// Tests that the Server.Client method works and returns an http.Client that can hit
+// NewTLSServer without cert warnings.
+func TestServerClient(t *testing.T) {
+       ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               w.Write([]byte("hello"))
+       }))
+       defer ts.Close()
+       client := ts.Client()
+       res, err := client.Get(ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       got, err := ioutil.ReadAll(res.Body)
+       res.Body.Close()
+       if err != nil {
+               t.Fatal(err)
+       }
+       if string(got) != "hello" {
+               t.Errorf("got %q, want hello", string(got))
+       }
+}