m.activeCertHandles = nil
}
+ if ch, ok := m.(*clientHelloMsg); ok {
+ // extensions is special cased, as it is only populated by the
+ // server-side of a handshake and is not expected to roundtrip
+ // through marshal + unmarshal. m ends up with the list of
+ // extensions necessary to serialize the other fields of
+ // clientHelloMsg, so check that it is non-empty, then clear it.
+ if len(ch.extensions) == 0 {
+ t.Errorf("expected ch.extensions to be populated on unmarshal")
+ }
+ ch.extensions = nil
+ }
+
// clientHelloMsg and serverHelloMsg, when unmarshalled, store
// their original representation, for later use in the handshake
// transcript. In order to prevent DeepEqual from failing since
"runtime"
"slices"
"strings"
+ "sync/atomic"
"testing"
"time"
)
runServerTestTLS12(t, test)
}
+// TestHandshakeServerGetCertificateExtensions tests to make sure that the
+// Extensions passed to GetCertificate match what we expect based on the
+// clientHelloMsg
+func TestHandshakeServerGetCertificateExtensions(t *testing.T) {
+ const errMsg = "TestHandshakeServerGetCertificateExtensions error"
+ // ensure the test condition inside our GetCertificate callback
+ // is actually invoked
+ var called atomic.Int32
+
+ testVersions := []uint16{VersionTLS12, VersionTLS13}
+ for _, vers := range testVersions {
+ t.Run(fmt.Sprintf("TLS version %04x", vers), func(t *testing.T) {
+ pk, _ := ecdh.X25519().GenerateKey(rand.Reader)
+ clientHello := &clientHelloMsg{
+ vers: vers,
+ random: make([]byte, 32),
+ cipherSuites: []uint16{TLS_CHACHA20_POLY1305_SHA256},
+ compressionMethods: []uint8{compressionNone},
+ serverName: "test",
+ keyShares: []keyShare{{group: X25519, data: pk.PublicKey().Bytes()}},
+ supportedCurves: []CurveID{X25519},
+ supportedSignatureAlgorithms: []SignatureScheme{Ed25519},
+ }
+
+ // the clientHelloMsg initialized just above is serialized with
+ // two extensions: server_name(0) and application_layer_protocol_negotiation(16)
+ expectedExtensions := []uint16{
+ extensionServerName,
+ extensionSupportedCurves,
+ extensionSignatureAlgorithms,
+ extensionKeyShare,
+ }
+
+ if vers == VersionTLS13 {
+ clientHello.supportedVersions = []uint16{VersionTLS13}
+ expectedExtensions = append(expectedExtensions, extensionSupportedVersions)
+ }
+
+ // Go's TLS client presents extensions in the ClientHello sorted by extension ID
+ slices.Sort(expectedExtensions)
+
+ serverConfig := testConfig.Clone()
+ serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+ if !slices.Equal(expectedExtensions, clientHello.Extensions) {
+ t.Errorf("expected extensions on ClientHelloInfo (%v) to match clientHelloMsg (%v)", expectedExtensions, clientHello.Extensions)
+ }
+ called.Add(1)
+
+ return nil, errors.New(errMsg)
+ }
+ testClientHelloFailure(t, serverConfig, clientHello, errMsg)
+ })
+ }
+
+ if int(called.Load()) != len(testVersions) {
+ t.Error("expected our GetCertificate test to be called twice")
+ }
+}
+
// TestHandshakeServerSNIGetCertificateError tests to make sure that errors in
// GetCertificate result in a tls alert.
func TestHandshakeServerSNIGetCertificateError(t *testing.T) {