]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add support for additional alpn flags to bogo_shim_test
authorClide Stefani <cstefani.sites@gmail.com>
Tue, 25 Jun 2024 19:52:32 +0000 (15:52 -0400)
committerGopher Robot <gobot@golang.org>
Wed, 14 Aug 2024 18:04:16 +0000 (18:04 +0000)
The existing implementation of bogo_shim_test does not support tests
that use the -expect-advertised-alpn flag or the -select-alpn flag. This
change allows bogo_shim_test to receive and enforce these flags.

Support for these flags is added in the same change because these flags are set together.

Updates #51434

Change-Id: Ia37f9e7403d4a43e6da68c16039a4bcb56ebd032
Reviewed-on: https://go-review.googlesource.com/c/go/+/595655
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Reviewed-by: Russell Webb <russell.webb@protonmail.com>
Reviewed-by: Clide Stefani <cstefani.sites@gmail.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
src/crypto/tls/bogo_shim_test.go

index ce01852aeeda30320af7666f260bd2a71b742612..ff836d93ed42bbcb97f00cbd3f8254c86e1624ee 100644 (file)
@@ -17,9 +17,12 @@ import (
        "os/exec"
        "path/filepath"
        "runtime"
+       "slices"
        "strconv"
        "strings"
        "testing"
+
+       "golang.org/x/crypto/cryptobyte"
 )
 
 var (
@@ -77,10 +80,12 @@ var (
        _                       = flag.Bool("expect-ticket-supports-early-data", false, "")
        onResumeShimWritesFirst = flag.Bool("on-resume-shim-writes-first", false, "")
 
-       advertiseALPN = flag.String("advertise-alpn", "", "")
-       expectALPN    = flag.String("expect-alpn", "", "")
-       rejectALPN    = flag.Bool("reject-alpn", false, "")
-       declineALPN   = flag.Bool("decline-alpn", false, "")
+       advertiseALPN        = flag.String("advertise-alpn", "", "")
+       expectALPN           = flag.String("expect-alpn", "", "")
+       rejectALPN           = flag.Bool("reject-alpn", false, "")
+       declineALPN          = flag.Bool("decline-alpn", false, "")
+       expectAdvertisedALPN = flag.String("expect-advertised-alpn", "", "")
+       selectALPN           = flag.String("select-alpn", "", "")
 
        hostName = flag.String("host-name", "", "")
 
@@ -118,6 +123,29 @@ func bogoShim() {
                MaxVersion: uint16(*maxVersion),
 
                ClientSessionCache: NewLRUClientSessionCache(0),
+
+               GetConfigForClient: func(chi *ClientHelloInfo) (*Config, error) {
+
+                       if *expectAdvertisedALPN != "" {
+
+                               s := cryptobyte.String(*expectAdvertisedALPN)
+
+                               var expectedALPNs []string
+
+                               for !s.Empty() {
+                                       var alpn cryptobyte.String
+                                       if !s.ReadUint8LengthPrefixed(&alpn) {
+                                               return nil, fmt.Errorf("unexpected error while parsing arguments for -expect-advertised-alpn")
+                                       }
+                                       expectedALPNs = append(expectedALPNs, string(alpn))
+                               }
+
+                               if !slices.Equal(chi.SupportedProtos, expectedALPNs) {
+                                       return nil, fmt.Errorf("unexpected ALPN: got %q, want %q", chi.SupportedProtos, expectedALPNs)
+                               }
+                       }
+                       return nil, nil
+               },
        }
 
        if *noTLS1 {
@@ -160,6 +188,9 @@ func bogoShim() {
        if *declineALPN {
                cfg.NextProtos = []string{}
        }
+       if *selectALPN != "" {
+               cfg.NextProtos = []string{*selectALPN}
+       }
 
        if *hostName != "" {
                cfg.ServerName = *hostName
@@ -288,6 +319,11 @@ func bogoShim() {
                        if *expectALPN != "" && cs.NegotiatedProtocol != *expectALPN {
                                log.Fatalf("unexpected protocol negotiated: want %q, got %q", *expectALPN, cs.NegotiatedProtocol)
                        }
+
+                       if *selectALPN != "" && cs.NegotiatedProtocol != *selectALPN {
+                               log.Fatalf("unexpected protocol negotiated: want %q, got %q", *selectALPN, cs.NegotiatedProtocol)
+                       }
+
                        if *expectVersion != 0 && cs.Version != uint16(*expectVersion) {
                                log.Fatalf("expected ssl version %q, got %q", uint16(*expectVersion), cs.Version)
                        }