]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: Set TLSClientConfig.ServerName on every HTTP request.
authorDave Borowitz <dborowitz@google.com>
Wed, 22 Aug 2012 16:15:41 +0000 (09:15 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 22 Aug 2012 16:15:41 +0000 (09:15 -0700)
This makes SNI "just work" for callers using the standard http.Client.

Since we now have a test that depends on the httptest.Server cert, change
the cert to be a CA (keeping all other fields the same).

R=bradfitz
CC=agl, dsymonds, gobot, golang-dev
https://golang.org/cl/6448154

src/pkg/net/http/client_test.go
src/pkg/net/http/httptest/server.go
src/pkg/net/http/transport.go

index da7a44da7a50857b8b237acca33dbae8933eb8a4..c61b17d289b334fa8eb272d13fba74c653b4cbe8 100644 (file)
@@ -8,6 +8,7 @@ package http_test
 
 import (
        "crypto/tls"
+       "crypto/x509"
        "errors"
        "fmt"
        "io"
@@ -470,3 +471,49 @@ func TestClientErrorWithRequestURI(t *testing.T) {
                t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
        }
 }
+
+func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
+       certs := x509.NewCertPool()
+       for _, c := range ts.TLS.Certificates {
+               roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+               if err != nil {
+                       t.Fatalf("error parsing server's root cert: %v", err)
+               }
+               for _, root := range roots {
+                       certs.AddCert(root)
+               }
+       }
+       return &Transport{
+               TLSClientConfig: &tls.Config{RootCAs: certs},
+       }
+}
+
+func TestClientWithCorrectTLSServerName(t *testing.T) {
+       ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if r.TLS.ServerName != "127.0.0.1" {
+                       t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
+               }
+       }))
+       defer ts.Close()
+
+       c := &Client{Transport: newTLSTransport(t, ts)}
+       if _, err := c.Get(ts.URL); err != nil {
+               t.Fatalf("expected successful TLS connection, got error: %v", err)
+       }
+}
+
+func TestClientWithIncorrectTLSServerName(t *testing.T) {
+       ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+       defer ts.Close()
+
+       trans := newTLSTransport(t, ts)
+       trans.TLSClientConfig.ServerName = "badserver"
+       c := &Client{Transport: trans}
+       _, err := c.Get(ts.URL)
+       if err == nil {
+               t.Fatalf("expected an error")
+       }
+       if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
+               t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
+       }
+}
index 57cf0c9417dbae8c822d28d7dc88a23e14a14e41..165600e52beba14cb12610f98ba73538d5bb0f02 100644 (file)
@@ -184,15 +184,15 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
 // of ASN.1 time).
 var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
+MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
 DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
 qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
-8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
-DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
-CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
-+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
------END CERTIFICATE-----
-`)
+8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB
+/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC
+CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB
+nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw
+Pg==
+-----END CERTIFICATE-----`)
 
 // localhostKey is the private key for localhostCert.
 var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
index fe6318824e0b13ae54e7d878b0f9e5b526ceb46b..a33d787f25d958aab58d99cba405844743b045a6 100644 (file)
@@ -379,7 +379,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
 
        if cm.targetScheme == "https" {
                // Initiate TLS and check remote host name against certificate.
-               conn = tls.Client(conn, t.TLSClientConfig)
+               cfg := t.TLSClientConfig
+               if cfg == nil || cfg.ServerName == "" {
+                       host, _, _ := net.SplitHostPort(cm.addr())
+                       if cfg == nil {
+                               cfg = &tls.Config{ServerName: host}
+                       } else {
+                               clone := *cfg // shallow clone
+                               clone.ServerName = host
+                               cfg = &clone
+                       }
+               }
+               conn = tls.Client(conn, cfg)
                if err = conn.(*tls.Conn).Handshake(); err != nil {
                        return nil, err
                }