]> Cypherpunks repositories - gostls13.git/commitdiff
httptest: add NewTLSServer
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 4 Apr 2011 15:32:59 +0000 (08:32 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 4 Apr 2011 15:32:59 +0000 (08:32 -0700)
Enables the use of https servers in tests.

R=agl, rsc, agl1
CC=golang-dev
https://golang.org/cl/4284063

src/pkg/crypto/tls/tls.go
src/pkg/http/httptest/server.go
src/pkg/http/serve_test.go

index f66449c822535f15839a2c376f4256552615a752..7de44bbd244df80afe6451ef0d8754e22e60f4d7 100644 (file)
@@ -124,7 +124,16 @@ func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err os.
        if err != nil {
                return
        }
+       keyPEMBlock, err := ioutil.ReadFile(keyFile)
+       if err != nil {
+               return
+       }
+       return X509KeyPair(certPEMBlock, keyPEMBlock)
+}
 
+// X509KeyPair parses a public/private key pair from a pair of
+// PEM encoded data.
+func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err os.Error) {
        var certDERBlock *pem.Block
        for {
                certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
@@ -141,11 +150,6 @@ func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err os.
                return
        }
 
-       keyPEMBlock, err := ioutil.ReadFile(keyFile)
-       if err != nil {
-               return
-       }
-
        keyDERBlock, _ := pem.Decode(keyPEMBlock)
        if keyDERBlock == nil {
                err = os.ErrorString("crypto/tls: failed to parse key PEM data")
index 6e825a890d19382122f768161fa0f685839ec9ae..8e385d045a1a64b95c3ee5270738209da92d445d 100644 (file)
@@ -7,10 +7,13 @@
 package httptest
 
 import (
+       "crypto/rand"
+       "crypto/tls"
        "fmt"
        "http"
-       "os"
        "net"
+       "os"
+       "time"
 )
 
 // A Server is an HTTP server listening on a system-chosen port on the
@@ -18,6 +21,7 @@ import (
 type Server struct {
        URL      string // base URL of form http://ipaddr:port with no trailing slash
        Listener net.Listener
+       TLS      *tls.Config // nil if not using using TLS
 }
 
 // historyListener keeps track of all connections that it's ever
@@ -35,16 +39,21 @@ func (hs *historyListener) Accept() (c net.Conn, err os.Error) {
        return
 }
 
-// NewServer starts and returns a new Server.
-// The caller should call Close when finished, to shut it down.
-func NewServer(handler http.Handler) *Server {
-       ts := new(Server)
+func newLocalListener() net.Listener {
        l, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {
                if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
                        panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
                }
        }
+       return l
+}
+
+// NewServer starts and returns a new Server.
+// The caller should call Close when finished, to shut it down.
+func NewServer(handler http.Handler) *Server {
+       ts := new(Server)
+       l := newLocalListener()
        ts.Listener = &historyListener{l, make([]net.Conn, 0)}
        ts.URL = "http://" + l.Addr().String()
        server := &http.Server{Handler: handler}
@@ -52,6 +61,32 @@ func NewServer(handler http.Handler) *Server {
        return ts
 }
 
+// NewTLSServer starts and returns a new Server using TLS.
+// The caller should call Close when finished, to shut it down.
+func NewTLSServer(handler http.Handler) *Server {
+       l := newLocalListener()
+       ts := new(Server)
+
+       cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+       if err != nil {
+               panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+       }
+
+       ts.TLS = &tls.Config{
+               Rand:         rand.Reader,
+               Time:         time.Seconds,
+               NextProtos:   []string{"http/1.1"},
+               Certificates: []tls.Certificate{cert},
+       }
+       tlsListener := tls.NewListener(l, ts.TLS)
+
+       ts.Listener = &historyListener{tlsListener, make([]net.Conn, 0)}
+       ts.URL = "https://" + l.Addr().String()
+       server := &http.Server{Handler: handler}
+       go server.Serve(ts.Listener)
+       return ts
+}
+
 // Close shuts down the server.
 func (s *Server) Close() {
        s.Listener.Close()
@@ -68,3 +103,34 @@ func (s *Server) CloseClientConnections() {
                conn.Close()
        }
 }
+
+// localhostCert is a PEM-encoded TLS cert with SAN DNS names
+// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
+// of ASN.1 time).
+var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIBwTCCASugAwIBAgIBADALBgkqhkiG9w0BAQUwADAeFw0xMTAzMzEyMDI1MDda
+Fw00OTEyMzEyMzU5NTlaMAAwggCdMAsGCSqGSIb3DQEBAQOCAIwAMIIAhwKCAIB6
+oy4iT42G6qk+GGn5VL5JlnJT6ZG5cqaMNFaNGlIxNb6CPUZLKq2sM3gRaimsktIw
+nNAcNwQGHpe1tZo+J/Pl04JTt71Y/TTAxy7OX27aZf1Rpt0SjdZ7vTPnFDPNsHGe
+KBKvPt55l2+YKjkZmV7eRevsVbpkNvNGB+T5d4Ge/wIBA6NPME0wDgYDVR0PAQH/
+BAQDAgCgMA0GA1UdDgQGBAQBAgMEMA8GA1UdIwQIMAaABAECAwQwGwYDVR0RBBQw
+EoIJMTI3LjAuMC4xggVbOjoxXTALBgkqhkiG9w0BAQUDggCBAHC3gbdvc44vs+wD
+g2kONiENnx8WKc0UTGg/TOXS3gaRb+CUIQtHWja65l8rAfclEovjHgZ7gx8brO0W
+JuC6p3MUAKsgOssIrrRIx2rpnfcmFVMzguCmrMNVmKUAalw18Yp0F72xYAIitVQl
+kJrLdIhBajcJRYu/YGltHQRaXuVt
+-----END CERTIFICATE-----
+`)
+
+// localhostKey is the private key for localhostCert.
+var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIIBkgIBAQKCAIB6oy4iT42G6qk+GGn5VL5JlnJT6ZG5cqaMNFaNGlIxNb6CPUZL
+Kq2sM3gRaimsktIwnNAcNwQGHpe1tZo+J/Pl04JTt71Y/TTAxy7OX27aZf1Rpt0S
+jdZ7vTPnFDPNsHGeKBKvPt55l2+YKjkZmV7eRevsVbpkNvNGB+T5d4Ge/wIBAwKC
+AIBRwh7Bil5Z8cYpZZv7jdQxDvbim7Z7ocRdeDmzZuF2I9RW04QyHHPIIlALnBvI
+YeF1veASz1gEFGUjzmbUGqKYSbCoTzXoev+F4bmbRxcX9sOmtslqvhMSHRSzA5NH
+aDVI3Hn4wvBVD8gePu8ACWqvPGbCiql11OKCMfjlPn2uuwJAx/24/F5DjXZ6hQQ7
+HxScOxKrpx5WnA9r1wZTltOTZkhRRzuLc21WJeE3M15QUdWi3zZxCKRFoth65HEs
+jy9YHQJAnPueRI44tz79b5QqVbeaOMUr7ZCb1Kp0uo6G+ANPLdlfliAupwij2eIz
+mHRJOWk0jBtXfRft1McH2H51CpXAyw==
+-----END RSA PRIVATE KEY-----
+`)
index b0e26e533596b423d0f7ffb61fa2ef6970c62967..cf889553fb7852f25d64efc9d7bcfa5fcbfb574d 100644 (file)
@@ -507,3 +507,30 @@ func TestHeadResponses(t *testing.T) {
                t.Errorf("got unexpected body %q", string(body))
        }
 }
+
+func TestTLSServer(t *testing.T) {
+       ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               fmt.Fprintf(w, "tls=%v", r.TLS != nil)
+       }))
+       defer ts.Close()
+       if !strings.HasPrefix(ts.URL, "https://") {
+               t.Fatalf("expected test TLS server to start with https://, got %q", ts.URL)
+       }
+       res, _, err := Get(ts.URL)
+       if err != nil {
+               t.Error(err)
+       }
+       if res == nil {
+               t.Fatalf("got nil Response")
+       }
+       if res.Body == nil {
+               t.Fatalf("got nil Response.Body")
+       }
+       body, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Error(err)
+       }
+       if e, g := "tls=true", string(body); e != g {
+               t.Errorf("expected body %q; got %q", e, g)
+       }
+}