]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add *Config argument to Dial
authorRuss Cox <rsc@golang.org>
Tue, 7 Dec 2010 21:15:15 +0000 (16:15 -0500)
committerRuss Cox <rsc@golang.org>
Tue, 7 Dec 2010 21:15:15 +0000 (16:15 -0500)
Document undocumented exported names.
Allow nil Rand, Time, RootCAs in Config.

Fixes #1248.

R=agl1
CC=golang-dev
https://golang.org/cl/3481042

src/pkg/crypto/tls/ca_set.go
src/pkg/crypto/tls/common.go
src/pkg/crypto/tls/handshake_client.go
src/pkg/crypto/tls/handshake_server.go
src/pkg/crypto/tls/tls.go
src/pkg/http/client.go
src/pkg/websocket/client.go

index fe2a540f4dbc5e768a8b2a3c60b2c6adcb42c5fe..ae00ac5586831549343fe18bff967a9767a9f7b5 100644 (file)
@@ -16,6 +16,7 @@ type CASet struct {
        byName         map[string][]*x509.Certificate
 }
 
+// NewCASet returns a new, empty CASet.
 func NewCASet() *CASet {
        return &CASet{
                make(map[string][]*x509.Certificate),
index a4f2b804f10621367ce9bcd170c9e1c9efd31519..4fb17ad3a89182624a63cbab8c3d0abc7a6e9ac8 100644 (file)
@@ -78,6 +78,7 @@ const (
        // Rest of these are reserved by the TLS spec
 )
 
+// ConnectionState records basic TLS details about the connection.
 type ConnectionState struct {
        HandshakeComplete  bool
        CipherSuite        uint16
@@ -88,28 +89,65 @@ type ConnectionState struct {
 // has been passed to a TLS function it must not be modified.
 type Config struct {
        // Rand provides the source of entropy for nonces and RSA blinding.
+       // If Rand is nil, TLS uses the cryptographic random reader in package
+       // crypto/rand.
        Rand io.Reader
+
        // Time returns the current time as the number of seconds since the epoch.
+       // If Time is nil, TLS uses the system time.Seconds.
        Time func() int64
-       // Certificates contains one or more certificate chains.
+
+       // Certificates contains one or more certificate chains
+       // to present to the other side of the connection.
+       // Server configurations must include at least one certificate.
        Certificates []Certificate
-       RootCAs      *CASet
+
+       // RootCAs defines the set of root certificate authorities
+       // that clients use when verifying server certificates.
+       // If RootCAs is nil, TLS uses the host's root CA set.
+       RootCAs *CASet
+
        // NextProtos is a list of supported, application level protocols.
        // Currently only server-side handling is supported.
        NextProtos []string
+
        // ServerName is included in the client's handshake to support virtual
        // hosting.
        ServerName string
-       // AuthenticateClient determines if a server will request a certificate
+
+       // AuthenticateClient controls whether a server will request a certificate
        // from the client. It does not require that the client send a
-       // certificate nor, if it does, that the certificate is anything more
-       // than self-signed.
+       // certificate nor does it require that the certificate sent be
+       // anything more than self-signed.
        AuthenticateClient bool
 }
 
+func (c *Config) rand() io.Reader {
+       r := c.Rand
+       if r == nil {
+               return rand.Reader
+       }
+       return r
+}
+
+func (c *Config) time() int64 {
+       t := c.Time
+       if t == nil {
+               t = time.Seconds
+       }
+       return t()
+}
+
+func (c *Config) rootCAs() *CASet {
+       s := c.RootCAs
+       if s == nil {
+               s = defaultRoots()
+       }
+       return s
+}
+
+// A Certificate is a chain of one or more certificates, leaf first.
 type Certificate struct {
-       // Certificate contains a chain of one or more certificates. Leaf
-       // certificate first.
        Certificate [][]byte
        PrivateKey  *rsa.PrivateKey
 }
@@ -143,14 +181,10 @@ func mutualVersion(vers uint16) (uint16, bool) {
        return vers, true
 }
 
-// The defaultConfig is used in place of a nil *Config in the TLS server and client.
-var varDefaultConfig *Config
-
-var once sync.Once
+var emptyConfig Config
 
 func defaultConfig() *Config {
-       once.Do(initDefaultConfig)
-       return varDefaultConfig
+       return &emptyConfig
 }
 
 // Possible certificate files; stop after finding one.
@@ -162,7 +196,16 @@ var certFiles = []string{
        "/usr/share/curl/curl-ca-bundle.crt", // OS X
 }
 
-func initDefaultConfig() {
+var once sync.Once
+
+func defaultRoots() *CASet {
+       once.Do(initDefaultRoots)
+       return varDefaultRoots
+}
+
+var varDefaultRoots *CASet
+
+func initDefaultRoots() {
        roots := NewCASet()
        for _, file := range certFiles {
                data, err := ioutil.ReadFile(file)
@@ -171,10 +214,5 @@ func initDefaultConfig() {
                        break
                }
        }
-
-       varDefaultConfig = &Config{
-               Rand:    rand.Reader,
-               Time:    time.Seconds,
-               RootCAs: roots,
-       }
+       varDefaultRoots = roots
 }
index b6b0e0fad3784f609c3ec1d3db88b02f2b5715a0..4cddba3303079355679525cc9b270952dbc12880 100644 (file)
@@ -30,12 +30,12 @@ func (c *Conn) clientHandshake() os.Error {
                serverName:         c.config.ServerName,
        }
 
-       t := uint32(c.config.Time())
+       t := uint32(c.config.time())
        hello.random[0] = byte(t >> 24)
        hello.random[1] = byte(t >> 16)
        hello.random[2] = byte(t >> 8)
        hello.random[3] = byte(t)
-       _, err := io.ReadFull(c.config.Rand, hello.random[4:])
+       _, err := io.ReadFull(c.config.rand(), hello.random[4:])
        if err != nil {
                c.sendAlert(alertInternalError)
                return os.ErrorString("short read from Rand")
@@ -217,12 +217,12 @@ func (c *Conn) clientHandshake() os.Error {
        preMasterSecret := make([]byte, 48)
        preMasterSecret[0] = byte(hello.vers >> 8)
        preMasterSecret[1] = byte(hello.vers)
-       _, err = io.ReadFull(c.config.Rand, preMasterSecret[2:])
+       _, err = io.ReadFull(c.config.rand(), preMasterSecret[2:])
        if err != nil {
                return c.sendAlert(alertInternalError)
        }
 
-       ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret)
+       ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.rand(), pub, preMasterSecret)
        if err != nil {
                return c.sendAlert(alertInternalError)
        }
@@ -235,7 +235,7 @@ func (c *Conn) clientHandshake() os.Error {
                var digest [36]byte
                copy(digest[0:16], finishedHash.serverMD5.Sum())
                copy(digest[16:36], finishedHash.serverSHA1.Sum())
-               signed, err := rsa.SignPKCS1v15(c.config.Rand, c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:])
+               signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:])
                if err != nil {
                        return c.sendAlert(alertInternalError)
                }
index 225503846100c9b12c2f7c9af35e8a0cb3faedba..6db2a6a1bf61f97dbeb1caf72b23089fb04b3b23 100644 (file)
@@ -84,13 +84,13 @@ func (c *Conn) serverHandshake() os.Error {
 
        hello.vers = vers
        hello.cipherSuite = suite.id
-       t := uint32(config.Time())
+       t := uint32(config.time())
        hello.random = make([]byte, 32)
        hello.random[0] = byte(t >> 24)
        hello.random[1] = byte(t >> 16)
        hello.random[2] = byte(t >> 8)
        hello.random[3] = byte(t)
-       _, err = io.ReadFull(config.Rand, hello.random[4:])
+       _, err = io.ReadFull(config.rand(), hello.random[4:])
        if err != nil {
                return c.sendAlert(alertInternalError)
        }
@@ -209,12 +209,12 @@ func (c *Conn) serverHandshake() os.Error {
        }
 
        preMasterSecret := make([]byte, 48)
-       _, err = io.ReadFull(config.Rand, preMasterSecret[2:])
+       _, err = io.ReadFull(config.rand(), preMasterSecret[2:])
        if err != nil {
                return c.sendAlert(alertInternalError)
        }
 
-       err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret)
+       err = rsa.DecryptPKCS1v15SessionKey(config.rand(), config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret)
        if err != nil {
                return c.sendAlert(alertHandshakeFailure)
        }
index 61f0a9702dc8af96c44d79107f366ab9ab5fa29a..b11d3225daa6eb7add50aa6deee219d36a8b6614 100644 (file)
@@ -15,19 +15,31 @@ import (
        "strings"
 )
 
+// Server returns a new TLS server side connection
+// using conn as the underlying transport.
+// The configuration config must be non-nil and must have
+// at least one certificate.
 func Server(conn net.Conn, config *Config) *Conn {
        return &Conn{conn: conn, config: config}
 }
 
+// Client returns a new TLS client side connection
+// using conn as the underlying transport.
+// Client interprets a nil configuration as equivalent to
+// the zero configuration; see the documentation of Config
+// for the defaults.
 func Client(conn net.Conn, config *Config) *Conn {
        return &Conn{conn: conn, config: config, isClient: true}
 }
 
+// A Listener implements a network listener (net.Listener) for TLS connections.
 type Listener struct {
        listener net.Listener
        config   *Config
 }
 
+// Accept waits for and returns the next incoming TLS connection.
+// The returned connection c is a *tls.Conn.
 func (l *Listener) Accept() (c net.Conn, err os.Error) {
        c, err = l.listener.Accept()
        if err != nil {
@@ -37,8 +49,10 @@ func (l *Listener) Accept() (c net.Conn, err os.Error) {
        return
 }
 
+// Close closes the listener.
 func (l *Listener) Close() os.Error { return l.listener.Close() }
 
+// Addr returns the listener's network address.
 func (l *Listener) Addr() net.Addr { return l.listener.Addr() }
 
 // NewListener creates a Listener which accepts connections from an inner
@@ -52,7 +66,11 @@ func NewListener(listener net.Listener, config *Config) (l *Listener) {
        return
 }
 
-func Listen(network, laddr string, config *Config) (net.Listener, os.Error) {
+// Listen creates a TLS listener accepting connections on the
+// given network address using net.Listen.
+// The configuration config must be non-nil and must have
+// at least one certificate.
+func Listen(network, laddr string, config *Config) (*Listener, os.Error) {
        if config == nil || len(config.Certificates) == 0 {
                return nil, os.NewError("tls.Listen: no certificates in configuration")
        }
@@ -63,7 +81,13 @@ func Listen(network, laddr string, config *Config) (net.Listener, os.Error) {
        return NewListener(l, config), nil
 }
 
-func Dial(network, laddr, raddr string) (net.Conn, os.Error) {
+// Dial connects to the given network address using net.Dial
+// and then initiates a TLS handshake, returning the resulting
+// TLS connection.
+// Dial interprets a nil configuration as equivalent to
+// the zero configuration; see the documentation of Config
+// for the defaults.
+func Dial(network, laddr, raddr string, config *Config) (*Conn, os.Error) {
        c, err := net.Dial(network, laddr, raddr)
        if err != nil {
                return nil, err
@@ -75,15 +99,21 @@ func Dial(network, laddr, raddr string) (net.Conn, os.Error) {
        }
        hostname := raddr[:colonPos]
 
-       config := defaultConfig()
-       config.ServerName = hostname
+       if config == nil {
+               config = defaultConfig()
+       }
+       if config.ServerName != "" {
+               // Make a copy to avoid polluting argument or default.
+               c := *config
+               c.ServerName = hostname
+               config = &c
+       }
        conn := Client(c, config)
-       err = conn.Handshake()
-       if err == nil {
-               return conn, nil
+       if err = conn.Handshake(); err != nil {
+               c.Close()
+               return nil, err
        }
-       c.Close()
-       return nil, err
+       return conn, nil
 }
 
 // LoadX509KeyPair reads and parses a public/private key pair from a pair of
index e902369e7c2d86976c7d9ae3cac2f38eb0ad1095..29678ee32aefcfa1faca99b727fae08ae287214a 100644 (file)
@@ -63,7 +63,7 @@ func send(req *Request) (resp *Response, err os.Error) {
                        return nil, err
                }
        } else { // https
-               conn, err = tls.Dial("tcp", "", addr)
+               conn, err = tls.Dial("tcp", "", addr, nil)
                if err != nil {
                        return nil, err
                }
index caf63f16f657d845a07dcc6a683c9d5cc5f53c85..09134594405b7572597174e0b2a9a6b5a9bdf643 100644 (file)
@@ -111,7 +111,7 @@ func Dial(url, protocol, origin string) (ws *Conn, err os.Error) {
                client, err = net.Dial("tcp", "", parsedUrl.Host)
 
        case "wss":
-               client, err = tls.Dial("tcp", "", parsedUrl.Host)
+               client, err = tls.Dial("tcp", "", parsedUrl.Host, nil)
 
        default:
                err = ErrBadScheme