errServerTemporarilyMisbehaving = errors.New("server misbehaving")
)
-func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
+func newRequest(q dnsmessage.Question, ad bool) (id uint16, udpReq, tcpReq []byte, err error) {
id = uint16(randInt())
- b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
+ b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true, AuthenticData: ad})
if err := b.StartQuestions(); err != nil {
return 0, nil, nil, err
}
}
// exchange sends a query on the connection and hopes for a response.
-func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP bool) (dnsmessage.Parser, dnsmessage.Header, error) {
+func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP, ad bool) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
- id, udpReq, tcpReq, err := newRequest(q)
+ id, udpReq, tcpReq, err := newRequest(q, ad)
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
}
for j := uint32(0); j < sLen; j++ {
server := cfg.servers[(serverOffset+j)%sLen]
- p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP)
+ p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP, cfg.trustAD)
if err != nil {
dnsErr := &DNSError{
Err: err.Error(),
for _, tt := range dnsTransportFallbackTests {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second, useUDPOrTCP)
+ _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second, useUDPOrTCP, false)
if err != nil {
t.Error(err)
continue
for _, tt := range specialDomainNameTests {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second, useUDPOrTCP)
+ _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second, useUDPOrTCP, false)
if err != nil {
t.Error(err)
continue
}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
ctx := context.Background()
- _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useUDPOrTCP)
+ _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useUDPOrTCP, false)
if err != nil {
t.Fatal("exhange failed:", err)
}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useTCPOnly)
+ _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useTCPOnly, false)
if err != nil {
t.Fatal("exchange failed:", err)
}
}
}
}
+
+func TestDNSTrustAD(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ if q.Questions[0].Name.String() == "notrustad.go.dev." && q.Header.AuthenticData {
+ t.Error("unexpected AD bit")
+ }
+
+ if q.Questions[0].Name.String() == "trustad.go.dev." && !q.Header.AuthenticData {
+ t.Error("expected AD bit")
+ }
+
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ if q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ }
+
+ return r, nil
+ }}
+
+ r := &Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ err = conf.writeAndUpdate([]string{"nameserver 127.0.0.1"})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := r.LookupIPAddr(context.Background(), "notrustad.go.dev"); err != nil {
+ t.Errorf("lookup failed: %v", err)
+ }
+
+ err = conf.writeAndUpdate([]string{"nameserver 127.0.0.1", "options trust-ad"})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := r.LookupIPAddr(context.Background(), "trustad.go.dev"); err != nil {
+ t.Errorf("lookup failed: %v", err)
+ }
+}