// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
+ handshakeErr error // error resulting from handshake
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
clientProtocol string
clientProtocolFallback bool
- // first permanent error
- connErr
-
// input/output
in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire
tmp [16]byte
}
-type connErr struct {
- mu sync.Mutex
- value error
-}
-
-func (e *connErr) setError(err error) error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.value == nil {
- e.value = err
- }
- return err
-}
-
-func (e *connErr) error() error {
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.value
-}
-
// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
// connection, either sending or receiving.
type halfConn struct {
sync.Mutex
+
+ err error // first permanent error
version uint16 // protocol version
cipher interface{} // cipher algorithm
mac macFunction
inDigestBuf, outDigestBuf []byte
}
+func (hc *halfConn) setErrorLocked(err error) error {
+ hc.err = err
+ return err
+}
+
+func (hc *halfConn) error() error {
+ hc.Lock()
+ err := hc.err
+ hc.Unlock()
+ return err
+}
+
// prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
switch want {
default:
c.sendAlert(alertInternalError)
- return errors.New("tls: unknown record type requested")
+ return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete {
c.sendAlert(alertInternalError)
- return errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete")
+ return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete"))
}
case recordTypeApplicationData:
if !c.handshakeComplete {
c.sendAlert(alertInternalError)
- return errors.New("tls: application data record requested before handshake complete")
+ return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete"))
}
}
// err = io.ErrUnexpectedEOF
// }
if e, ok := err.(net.Error); !ok || !e.Temporary() {
- c.setError(err)
+ c.in.setErrorLocked(err)
}
return err
}
// an SSLv2 client.
if want == recordTypeHandshake && typ == 0x80 {
c.sendAlert(alertProtocolVersion)
- return errors.New("tls: unsupported SSLv2 handshake received")
+ return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
}
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
n := int(b.data[3])<<8 | int(b.data[4])
if c.haveVers && vers != c.vers {
c.sendAlert(alertProtocolVersion)
- return fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers)
+ return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
}
if n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
- return fmt.Errorf("tls: oversized record received with length %d", n)
+ return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
}
if !c.haveVers {
// First message, be extra suspicious:
// it's probably not real.
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
c.sendAlert(alertUnexpectedMessage)
- return fmt.Errorf("tls: first record does not look like a TLS handshake")
+ return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
}
}
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
err = io.ErrUnexpectedEOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() {
- c.setError(err)
+ c.in.setErrorLocked(err)
}
return err
}
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
ok, off, err := c.in.decrypt(b)
if !ok {
- return c.sendAlert(err)
+ c.in.setErrorLocked(c.sendAlert(err))
}
b.off = off
data := b.data[b.off:]
if len(data) > maxPlaintext {
- c.sendAlert(alertRecordOverflow)
+ err := c.sendAlert(alertRecordOverflow)
c.in.freeBlock(b)
- return c.error()
+ return c.in.setErrorLocked(err)
}
switch typ {
default:
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert:
if len(data) != 2 {
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
if alert(data[1]) == alertCloseNotify {
- c.setError(io.EOF)
+ c.in.setErrorLocked(io.EOF)
break
}
switch data[0] {
c.in.freeBlock(b)
goto Again
case alertLevelError:
- c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
+ c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
default:
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
case recordTypeChangeCipherSpec:
if typ != want || len(data) != 1 || data[0] != 1 {
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
err := c.in.changeCipherSpec()
if err != nil {
- c.sendAlert(err.(alert))
+ c.in.setErrorLocked(c.sendAlert(err.(alert)))
}
case recordTypeApplicationData:
if typ != want {
- c.sendAlert(alertUnexpectedMessage)
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
c.input = b
case recordTypeHandshake:
// TODO(rsc): Should at least pick off connection close.
if typ != want {
- return c.sendAlert(alertNoRenegotiation)
+ return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
}
c.hand.Write(data)
}
if b != nil {
c.in.freeBlock(b)
}
- return c.error()
+ return c.in.err
}
// sendAlert sends a TLS alert message.
c.writeRecord(recordTypeAlert, c.tmp[0:2])
// closeNotify is a special case in that it isn't an error:
if err != alertCloseNotify {
- return c.setError(&net.OpError{Op: "local error", Err: err})
+ return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
return nil
}
c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2])
- return n, c.setError(&net.OpError{Op: "local error", Err: err})
+ return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
}
return
// c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (interface{}, error) {
for c.hand.Len() < 4 {
- if err := c.error(); err != nil {
+ if err := c.in.err; err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
- c.sendAlert(alertInternalError)
- return nil, c.error()
+ return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
}
for c.hand.Len() < 4+n {
- if err := c.error(); err != nil {
+ if err := c.in.err; err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
case typeFinished:
m = new(finishedMsg)
default:
- c.sendAlert(alertUnexpectedMessage)
- return nil, alertUnexpectedMessage
+ return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// The handshake message unmarshallers
data = append([]byte(nil), data...)
if !m.unmarshal(data) {
- c.sendAlert(alertUnexpectedMessage)
- return nil, alertUnexpectedMessage
+ return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
return m, nil
}
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
- if err := c.error(); err != nil {
- return 0, err
- }
-
if err := c.Handshake(); err != nil {
- return 0, c.setError(err)
+ return 0, err
}
c.out.Lock()
defer c.out.Unlock()
+ if err := c.out.err; err != nil {
+ return 0, err
+ }
+
if !c.handshakeComplete {
return 0, alertInternalError
}
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecord(recordTypeApplicationData, b[:1])
if err != nil {
- return n, c.setError(err)
+ return n, c.out.setErrorLocked(err)
}
m, b = 1, b[1:]
}
}
n, err := c.writeRecord(recordTypeApplicationData, b)
- return n + m, c.setError(err)
+ return n + m, c.out.setErrorLocked(err)
}
// Read can be made to time out and return a net.Error with Timeout() == true
// CBC IV. So this loop ignores a limited number of empty records.
const maxConsecutiveEmptyRecords = 100
for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
- for c.input == nil && c.error() == nil {
+ for c.input == nil && c.in.err == nil {
if err := c.readRecord(recordTypeApplicationData); err != nil {
// Soft error, like EAGAIN
return 0, err
}
}
- if err := c.error(); err != nil {
+ if err := c.in.err; err != nil {
return 0, err
}
func (c *Conn) Handshake() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
- if err := c.error(); err != nil {
+ if err := c.handshakeErr; err != nil {
return err
}
if c.handshakeComplete {
return nil
}
+
if c.isClient {
- return c.clientHandshake()
+ c.handshakeErr = c.clientHandshake()
+ } else {
+ c.handshakeErr = c.serverHandshake()
}
- return c.serverHandshake()
+ return c.handshakeErr
}
// ConnectionState returns basic TLS details about the connection.