]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: help linker remove code when only Client or Server is used
authorBrad Fitzpatrick <brad@danga.com>
Wed, 15 Apr 2020 16:06:34 +0000 (09:06 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 15 Apr 2020 19:49:43 +0000 (19:49 +0000)
This saves 166 KiB for a tls.Dial hello world program (5382441 to
5212356 to bytes), by permitting the linker to remove TLS server code.

Change-Id: I16610b836bb0802b7d84995ff881d79ec03b6a84
Reviewed-on: https://go-review.googlesource.com/c/go/+/228111
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/crypto/tls/conn.go
src/crypto/tls/link_test.go [new file with mode: 0644]
src/crypto/tls/tls.go

index eeab030ecaeb52f50686506f154b6f435d3f31e1..6bda73e0857f8df013212f5dfe380e415cc50c99 100644 (file)
@@ -24,8 +24,9 @@ import (
 // It implements the net.Conn interface.
 type Conn struct {
        // constant
-       conn     net.Conn
-       isClient bool
+       conn        net.Conn
+       isClient    bool
+       handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
 
        // handshakeStatus is 1 if the connection is currently transferring
        // application data (i.e. is not currently processing a handshake).
@@ -1349,11 +1350,7 @@ func (c *Conn) Handshake() error {
        c.in.Lock()
        defer c.in.Unlock()
 
-       if c.isClient {
-               c.handshakeErr = c.clientHandshake()
-       } else {
-               c.handshakeErr = c.serverHandshake()
-       }
+       c.handshakeErr = c.handshakeFn()
        if c.handshakeErr == nil {
                c.handshakes++
        } else {
diff --git a/src/crypto/tls/link_test.go b/src/crypto/tls/link_test.go
new file mode 100644 (file)
index 0000000..c1fb57e
--- /dev/null
@@ -0,0 +1,121 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+       "bytes"
+       "internal/testenv"
+       "io/ioutil"
+       "os"
+       "os/exec"
+       "path/filepath"
+       "testing"
+)
+
+// Tests that the linker is able to remove references to the Client or Server if unused.
+func TestLinkerGC(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping in short mode")
+       }
+       t.Parallel()
+       goBin := testenv.GoToolPath(t)
+       testenv.MustHaveGoBuild(t)
+
+       tests := []struct {
+               name    string
+               program string
+               want    []string
+               bad     []string
+       }{
+               {
+                       name: "empty_import",
+                       program: `package main
+import _ "crypto/tls"
+func main() {}
+`,
+                       bad: []string{
+                               "tls.(*Conn)",
+                               "type.crypto/tls.clientHandshakeState",
+                               "type.crypto/tls.serverHandshakeState",
+                       },
+               },
+               {
+                       name: "only_conn",
+                       program: `package main
+import "crypto/tls"
+var c = new(tls.Conn)
+func main() {}
+`,
+                       want: []string{"tls.(*Conn)"},
+                       bad: []string{
+                               "type.crypto/tls.clientHandshakeState",
+                               "type.crypto/tls.serverHandshakeState",
+                       },
+               },
+               {
+                       name: "client_and_server",
+                       program: `package main
+import "crypto/tls"
+func main() {
+  tls.Dial("", "", nil)
+  tls.Server(nil, nil)
+}
+`,
+                       want: []string{
+                               "crypto/tls.(*Conn).clientHandshake",
+                               "crypto/tls.(*Conn).serverHandshake",
+                       },
+               },
+               {
+                       name: "only_client",
+                       program: `package main
+import "crypto/tls"
+func main() { tls.Dial("", "", nil) }
+`,
+                       want: []string{
+                               "crypto/tls.(*Conn).clientHandshake",
+                       },
+                       bad: []string{
+                               "crypto/tls.(*Conn).serverHandshake",
+                       },
+               },
+               // TODO: add only_server like func main() { tls.Server(nil, nil) }
+               // That currently brings in the client via Conn.handleRenegotiation.
+
+       }
+       tmpDir := t.TempDir()
+       goFile := filepath.Join(tmpDir, "x.go")
+       exeFile := filepath.Join(tmpDir, "x.exe")
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if err := ioutil.WriteFile(goFile, []byte(tt.program), 0644); err != nil {
+                               t.Fatal(err)
+                       }
+                       os.Remove(exeFile)
+                       cmd := exec.Command(goBin, "build", "-o", "x.exe", "x.go")
+                       cmd.Dir = tmpDir
+                       if out, err := cmd.CombinedOutput(); err != nil {
+                               t.Fatalf("compile: %v, %s", err, out)
+                       }
+
+                       cmd = exec.Command(goBin, "tool", "nm", "x.exe")
+                       cmd.Dir = tmpDir
+                       nm, err := cmd.CombinedOutput()
+                       if err != nil {
+                               t.Fatalf("nm: %v, %s", err, nm)
+                       }
+                       for _, sym := range tt.want {
+                               if !bytes.Contains(nm, []byte(sym)) {
+                                       t.Errorf("expected symbol %q not found", sym)
+                               }
+                       }
+                       for _, sym := range tt.bad {
+                               if bytes.Contains(nm, []byte(sym)) {
+                                       t.Errorf("unexpected symbol %q found", sym)
+                               }
+                       }
+               })
+       }
+}
index af44485f44ebb758ad7e41c18656fa2cf24a44de..d98abdaea15a959df5e1e63db7f238365b5cb572 100644 (file)
@@ -32,7 +32,12 @@ import (
 // The configuration config must be non-nil and must include
 // at least one certificate or else set GetCertificate.
 func Server(conn net.Conn, config *Config) *Conn {
-       return &Conn{conn: conn, config: config}
+       c := &Conn{
+               conn:   conn,
+               config: config,
+       }
+       c.handshakeFn = c.serverHandshake
+       return c
 }
 
 // Client returns a new TLS client side connection
@@ -40,7 +45,13 @@ func Server(conn net.Conn, config *Config) *Conn {
 // The config cannot be nil: users must set either ServerName or
 // InsecureSkipVerify in the config.
 func Client(conn net.Conn, config *Config) *Conn {
-       return &Conn{conn: conn, config: config, isClient: true}
+       c := &Conn{
+               conn:     conn,
+               config:   config,
+               isClient: true,
+       }
+       c.handshakeFn = c.clientHandshake
+       return c
 }
 
 // A listener implements a network listener (net.Listener) for TLS connections.