"context"
"crypto/rand"
"crypto/tls"
+ "crypto/x509"
"encoding/binary"
"errors"
"fmt"
}
}
+func TestTransportEventTraceTLSVerify(t *testing.T) {
+ var mu sync.Mutex
+ var buf bytes.Buffer
+ logf := func(format string, args ...interface{}) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n')
+ }
+
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Error("Unexpected request")
+ }))
+ defer ts.Close()
+
+ certpool := x509.NewCertPool()
+ certpool.AddCert(ts.Certificate())
+
+ c := &Client{Transport: &Transport{
+ TLSClientConfig: &tls.Config{
+ ServerName: "dns-is-faked.golang",
+ RootCAs: certpool,
+ },
+ }}
+
+ trace := &httptrace.ClientTrace{
+ TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
+ TLSHandshakeDone: func(s tls.ConnectionState, err error) {
+ logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
+ },
+ }
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
+ _, err := c.Do(req)
+ if err == nil {
+ t.Error("Expected request to fail TLS verification")
+ }
+
+ mu.Lock()
+ got := buf.String()
+ mu.Unlock()
+
+ wantOnce := func(sub string) {
+ if strings.Count(got, sub) != 1 {
+ t.Errorf("expected substring %q exactly once in output.", sub)
+ }
+ }
+
+ wantOnce("TLSHandshakeStart")
+ wantOnce("TLSHandshakeDone")
+ wantOnce("err = x509: certificate is valid for example.com")
+
+ if t.Failed() {
+ t.Errorf("Output:\n%s", got)
+ }
+}
+
var (
isDNSHijackedOnce sync.Once
isDNSHijacked bool