]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: pool Conn's outBuf to reduce memory cost of idle connections
authorcch123 <buaa.cch@gmail.com>
Mon, 19 Oct 2020 12:55:46 +0000 (12:55 +0000)
committerFilippo Valsorda <filippo@golang.org>
Mon, 9 Nov 2020 13:12:41 +0000 (13:12 +0000)
Derived from CL 263277, which includes benchmarks.

Fixes #42035

Co-authored-by: Filippo Valsorda <filippo@golang.org>
Change-Id: I5f28673f95d4568b7d13dbc20e9d4b48d481a93d
Reviewed-on: https://go-review.googlesource.com/c/go/+/267957
Run-TryBot: Dmitri Shuralyov <dmitshur@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Filippo Valsorda <filippo@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Roberto Clapis <roberto@golang.org>
src/crypto/tls/conn.go

index ada19d6e7ae6e26b0c2940306bc98648fe1b70dc..b9a1095862a7754e32ce567262b29610594f1ed7 100644 (file)
@@ -94,7 +94,6 @@ type Conn struct {
        rawInput  bytes.Buffer // raw input, starting with a record header
        input     bytes.Reader // application data waiting to be read, from rawInput.Next
        hand      bytes.Buffer // handshake data waiting to be read
-       outBuf    []byte       // scratch buffer used by out.encrypt
        buffering bool         // whether records are buffered in sendBuf
        sendBuf   []byte       // a buffer of records waiting to be sent
 
@@ -928,9 +927,28 @@ func (c *Conn) flush() (int, error) {
        return n, err
 }
 
+// outBufPool pools the record-sized scratch buffers used by writeRecordLocked.
+var outBufPool = sync.Pool{
+       New: func() interface{} {
+               return new([]byte)
+       },
+}
+
 // writeRecordLocked writes a TLS record with the given type and payload to the
 // connection and updates the record layer state.
 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
+       outBufPtr := outBufPool.Get().(*[]byte)
+       outBuf := *outBufPtr
+       defer func() {
+               // You might be tempted to simplify this by just passing &outBuf to Put,
+               // but that would make the local copy of the outBuf slice header escape
+               // to the heap, causing an allocation. Instead, we keep around the
+               // pointer to the slice header returned by Get, which is already on the
+               // heap, and overwrite and return that.
+               *outBufPtr = outBuf
+               outBufPool.Put(outBufPtr)
+       }()
+
        var n int
        for len(data) > 0 {
                m := len(data)
@@ -938,8 +956,8 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
                        m = maxPayload
                }
 
-               _, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen)
-               c.outBuf[0] = byte(typ)
+               _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
+               outBuf[0] = byte(typ)
                vers := c.vers
                if vers == 0 {
                        // Some TLS servers fail if the record version is
@@ -950,17 +968,17 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
                        // See RFC 8446, Section 5.1.
                        vers = VersionTLS12
                }
-               c.outBuf[1] = byte(vers >> 8)
-               c.outBuf[2] = byte(vers)
-               c.outBuf[3] = byte(m >> 8)
-               c.outBuf[4] = byte(m)
+               outBuf[1] = byte(vers >> 8)
+               outBuf[2] = byte(vers)
+               outBuf[3] = byte(m >> 8)
+               outBuf[4] = byte(m)
 
                var err error
-               c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand())
+               outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
                if err != nil {
                        return n, err
                }
-               if _, err := c.write(c.outBuf); err != nil {
+               if _, err := c.write(outBuf); err != nil {
                        return n, err
                }
                n += m