"context", "math/rand", "os", "reflect", "sort", "syscall", "time",
"internal/nettrace", "internal/poll",
"internal/syscall/windows", "internal/singleflight", "internal/race",
- "golang_org/x/net/lif", "golang_org/x/net/route",
+ "golang_org/x/net/dns/dnsmessage", "golang_org/x/net/lif", "golang_org/x/net/route",
},
// NET enables use of basic network-related packages.
import (
"math/rand"
"sort"
+
+ "golang_org/x/net/dns/dnsmessage"
)
// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
return string(buf), nil
}
-// Find answer for name in dns message.
-// On return, if err == nil, addrs != nil.
-func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err error) {
- addrs = make([]dnsRR, 0, len(dns.answer))
-
- if dns.rcode == dnsRcodeNameError {
- return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
- }
- if dns.rcode != dnsRcodeSuccess {
- // None of the error codes make sense
- // for the query we sent. If we didn't get
- // a name error and we didn't get success,
- // the server is behaving incorrectly or
- // having temporary trouble.
- err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
- if dns.rcode == dnsRcodeServerFailure {
- err.IsTemporary = true
- }
- return "", nil, err
- }
-
- // Look for the name.
- // Presotto says it's okay to assume that servers listed in
- // /etc/resolv.conf are recursive resolvers.
- // We asked for recursion, so it should have included
- // all the answers we need in this one packet.
-Cname:
- for cnameloop := 0; cnameloop < 10; cnameloop++ {
- addrs = addrs[0:0]
- for _, rr := range dns.answer {
- if _, justHeader := rr.(*dnsRR_Header); justHeader {
- // Corrupt record: we only have a
- // header. That header might say it's
- // of type qtype, but we don't
- // actually have it. Skip.
- continue
- }
- h := rr.Header()
- if h.Class == dnsClassINET && equalASCIILabel(h.Name, name) {
- switch h.Rrtype {
- case qtype:
- addrs = append(addrs, rr)
- case dnsTypeCNAME:
- // redirect to cname
- name = rr.(*dnsRR_CNAME).Cname
- continue Cname
- }
- }
- }
- if len(addrs) == 0 {
- return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
- }
- return name, addrs, nil
- }
-
- return "", nil, &DNSError{Err: "too many redirects", Name: name, Server: server}
-}
-
-func equalASCIILabel(x, y string) bool {
- if len(x) != len(y) {
+func equalASCIIName(x, y dnsmessage.Name) bool {
+ if x.Length != y.Length {
return false
}
- for i := 0; i < len(x); i++ {
- a := x[i]
- b := y[i]
+ for i := 0; i < int(x.Length); i++ {
+ a := x.Data[i]
+ b := y.Data[i]
if 'A' <= a && a <= 'Z' {
a += 0x20
}
func TestWeighting(t *testing.T) {
testWeighting(t, 0.05)
}
-
-// Issue 8434: verify that Temporary returns true on an error when rcode
-// is SERVFAIL
-func TestIssue8434(t *testing.T) {
- msg := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- rcode: dnsRcodeServerFailure,
- },
- }
-
- _, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
- if err == nil {
- t.Fatal("expected an error")
- }
- if ne, ok := err.(Error); !ok {
- t.Fatalf("err = %#v; wanted something supporting net.Error", err)
- } else if !ne.Temporary() {
- t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
- }
- if de, ok := err.(*DNSError); !ok {
- t.Fatalf("err = %#v; wanted a *net.DNSError", err)
- } else if !de.IsTemporary {
- t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
- }
-}
-
-// Issue 12778: verify that NXDOMAIN without RA bit errors as
-// "no such host" and not "server misbehaving"
-func TestIssue12778(t *testing.T) {
- msg := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- rcode: dnsRcodeNameError,
- recursion_available: false,
- },
- }
-
- _, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
- if err == nil {
- t.Fatal("expected an error")
- }
- de, ok := err.(*DNSError)
- if !ok {
- t.Fatalf("err = %#v; wanted a *net.DNSError", err)
- }
- if de.Err != errNoSuchHost.Error() {
- t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
- }
-}
"os"
"sync"
"time"
-)
-
-// A dnsConn represents a DNS transport endpoint.
-type dnsConn interface {
- io.Closer
- SetDeadline(time.Time) error
+ "golang_org/x/net/dns/dnsmessage"
+)
- // dnsRoundTrip executes a single DNS transaction, returning a
- // DNS response message for the provided DNS query message.
- dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
+func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
+ id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
+ b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
+ b.EnableCompression()
+ if err := b.StartQuestions(); err != nil {
+ return 0, nil, nil, err
+ }
+ if err := b.Question(q); err != nil {
+ return 0, nil, nil, err
+ }
+ tcpReq, err = b.Finish()
+ udpReq = tcpReq[2:]
+ l := len(tcpReq) - 2
+ tcpReq[0] = byte(l >> 8)
+ tcpReq[1] = byte(l)
+ return id, udpReq, tcpReq, err
}
-// dnsPacketConn implements the dnsConn interface for RFC 1035's
-// "UDP usage" transport mechanism. Conn is a packet-oriented connection,
-// such as a *UDPConn.
-type dnsPacketConn struct {
- Conn
+func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
+ if !respHdr.Response {
+ return false
+ }
+ if reqID != respHdr.ID {
+ return false
+ }
+ if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
+ return false
+ }
+ return true
}
-func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
- b, ok := query.Pack()
- if !ok {
- return nil, errors.New("cannot marshal DNS message")
- }
+func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
if _, err := c.Write(b); err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
b = make([]byte, 512) // see RFC 1035
for {
n, err := c.Read(b)
if err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
- resp := &dnsMsg{}
- if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
- // Ignore invalid responses as they may be malicious
- // forgery attempts. Instead continue waiting until
- // timeout. See golang.org/issue/13281.
+ var p dnsmessage.Parser
+ // Ignore invalid responses as they may be malicious
+ // forgery attempts. Instead continue waiting until
+ // timeout. See golang.org/issue/13281.
+ h, err := p.Start(b[:n])
+ if err != nil {
continue
}
- return resp, nil
+ q, err := p.Question()
+ if err != nil || !checkResponse(id, query, h, q) {
+ continue
+ }
+ return p, h, nil
}
}
-// dnsStreamConn implements the dnsConn interface for RFC 1035's
-// "TCP usage" transport mechanism. Conn is a stream-oriented connection,
-// such as a *TCPConn.
-type dnsStreamConn struct {
- Conn
-}
-
-func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
- b, ok := query.Pack()
- if !ok {
- return nil, errors.New("cannot marshal DNS message")
- }
- l := len(b)
- b = append([]byte{byte(l >> 8), byte(l)}, b...)
+func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
if _, err := c.Write(b); err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
- l = int(b[0])<<8 | int(b[1])
+ l := int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
n, err := io.ReadFull(c, b[:l])
if err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
- resp := &dnsMsg{}
- if !resp.Unpack(b[:n]) {
- return nil, errors.New("cannot unmarshal DNS message")
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message")
}
- if !resp.IsResponseTo(query) {
- return nil, errors.New("invalid DNS response")
+ q, err := p.Question()
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message")
}
- return resp, nil
+ if !checkResponse(id, query, h, q) {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response")
+ }
+ return p, h, nil
}
// exchange sends a query on the connection and hopes for a response.
-func (r *Resolver) exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
- out := dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- recursion_desired: true,
- },
- question: []dnsQuestion{
- {name, qtype, dnsClassINET},
- },
+func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+ q.Class = dnsmessage.ClassINET
+ id, udpReq, tcpReq, err := newRequest(q)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot marshal DNS message")
}
for _, network := range []string{"udp", "tcp"} {
- // TODO(mdempsky): Refactor so defers from UDP-based
- // exchanges happen before TCP-based exchange.
-
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
c, err := r.dial(ctx, network, server)
if err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
- defer c.Close()
if d, ok := ctx.Deadline(); ok && !d.IsZero() {
c.SetDeadline(d)
}
- out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
- in, err := c.dnsRoundTrip(&out)
+ var p dnsmessage.Parser
+ var h dnsmessage.Header
+ if network == "tcp" {
+ p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
+ } else {
+ p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
+ }
+ c.Close()
if err != nil {
- return nil, mapErr(err)
+ return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
+ }
+ if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response")
}
- if in.truncated { // see RFC 5966
+ if h.Truncated { // see RFC 5966
continue
}
- return in, nil
+ return p, h, nil
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("no answer from DNS server")
+}
+
+func checkHeaders(p *dnsmessage.Parser, h dnsmessage.Header, name, server string) error {
+ _, err := p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+
+ // libresolv continues to the next server when it receives
+ // an invalid referral response. See golang.org/issue/15434.
+ if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
+ return &DNSError{Err: "lame referral", Name: name, Server: server}
+ }
+
+ // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
+ // it means the response in msg was not useful and trying another
+ // server probably won't help. Return now in those cases.
+ // TODO: indicate this in a more obvious way, such as a field on DNSError?
+ if h.RCode == dnsmessage.RCodeNameError {
+ return &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
+ }
+ if h.RCode != dnsmessage.RCodeSuccess {
+ // None of the error codes make sense
+ // for the query we sent. If we didn't get
+ // a name error and we didn't get success,
+ // the server is behaving incorrectly or
+ // having temporary trouble.
+ err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
+ if h.RCode == dnsmessage.RCodeServerFailure {
+ err.IsTemporary = true
+ }
+ return err
+ }
+
+ return nil
+}
+
+func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type, name, server string) error {
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ return &DNSError{
+ Err: errNoSuchHost.Error(),
+ Name: name,
+ Server: server,
+ }
+ }
+ if err != nil {
+ return &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type == qtype {
+ return nil
+ }
+ if err := p.SkipAnswer(); err != nil {
+ return &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
}
- return nil, errors.New("no answer from DNS server")
}
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
-func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
+func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
var lastErr error
serverOffset := cfg.serverOffset()
sLen := uint32(len(cfg.servers))
+ n, err := dnsmessage.NewName(name)
+ if err != nil {
+ return dnsmessage.Parser{}, "", errors.New("cannot marshal DNS message")
+ }
+ q := dnsmessage.Question{
+ Name: n,
+ Type: qtype,
+ Class: dnsmessage.ClassINET,
+ }
+
for i := 0; i < cfg.attempts; i++ {
for j := uint32(0); j < sLen; j++ {
server := cfg.servers[(serverOffset+j)%sLen]
- msg, err := r.exchange(ctx, server, name, qtype, cfg.timeout)
+ p, h, err := r.exchange(ctx, server, q, cfg.timeout)
if err != nil {
lastErr = &DNSError{
Err: err.Error(),
}
continue
}
- // libresolv continues to the next server when it receives
- // an invalid referral response. See golang.org/issue/15434.
- if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 {
- lastErr = &DNSError{Err: "lame referral", Name: name, Server: server}
+
+ lastErr = checkHeaders(&p, h, name, server)
+ if lastErr != nil {
continue
}
- cname, rrs, err := answer(name, server, msg, qtype)
- // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
- // it means the response in msg was not useful and trying another
- // server probably won't help. Return now in those cases.
- // TODO: indicate this in a more obvious way, such as a field on DNSError?
- if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError {
- return cname, rrs, err
+
+ lastErr = skipToAnswer(&p, qtype, name, server)
+ if lastErr == nil {
+ return p, server, nil
}
- lastErr = err
}
}
- return "", nil, lastErr
-}
-
-// addrRecordList converts and returns a list of IP addresses from DNS
-// address records (both A and AAAA). Other record types are ignored.
-func addrRecordList(rrs []dnsRR) []IPAddr {
- addrs := make([]IPAddr, 0, 4)
- for _, rr := range rrs {
- switch rr := rr.(type) {
- case *dnsRR_A:
- addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
- case *dnsRR_AAAA:
- ip := make(IP, IPv6len)
- copy(ip, rr.AAAA[:])
- addrs = append(addrs, IPAddr{IP: ip})
- }
- }
- return addrs
+ return dnsmessage.Parser{}, "", lastErr
}
// A resolverConfig represents a DNS stub resolver configuration.
<-conf.ch
}
-func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
+func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
if !isDomainName(name) {
// We used to use "invalid domain name" as the error,
// but that is a detail of the specific lookup mechanism.
// Other lookups might allow broader name syntax
// (for example Multicast DNS allows UTF-8; see RFC 6762).
// For consistency with libc resolvers, report no such host.
- return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name}
+ return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
+ var (
+ p dnsmessage.Parser
+ server string
+ err error
+ )
for _, fqdn := range conf.nameList(name) {
- cname, rrs, err = r.tryOneName(ctx, conf, fqdn, qtype)
+ p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
if err == nil {
break
}
break
}
}
+ if err == nil {
+ return p, server, nil
+ }
if err, ok := err.(*DNSError); ok {
// Show original name passed to lookup, not suffixed one.
// In general we might have tried many suffixes; showing
// just one is misleading. See also golang.org/issue/6324.
err.Name = name
}
- return
+ return dnsmessage.Parser{}, "", err
}
// avoidDNS reports whether this is a hostname for which we should not
return
}
-func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) {
+func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
addrs = goLookupIPFiles(name)
if len(addrs) > 0 || order == hostLookupFiles {
- return addrs, name, nil
+ return addrs, dnsmessage.Name{}, nil
}
}
if !isDomainName(name) {
// See comment in func lookup above about use of errNoSuchHost.
- return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
+ return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
type racer struct {
- cname string
- rrs []dnsRR
+ p dnsmessage.Parser
+ server string
error
}
lane := make(chan racer, 1)
- qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA}
+ qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
var lastErr error
for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes {
dnsWaitGroup.Add(1)
- go func(qtype uint16) {
- defer dnsWaitGroup.Done()
- cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype)
- lane <- racer{cname, rrs, err}
+ go func(qtype dnsmessage.Type) {
+ p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
+ lane <- racer{p, server, err}
+ dnsWaitGroup.Done()
}(qtype)
}
hitStrictError := false
}
continue
}
- addrs = append(addrs, addrRecordList(racer.rrs)...)
- if cname == "" {
- cname = racer.cname
+
+ // Presotto says it's okay to assume that servers listed in
+ // /etc/resolv.conf are recursive resolvers.
+ //
+ // We asked for recursion, so it should have included all the
+ // answers we need in this one packet.
+ //
+ // Further, RFC 1035 section 4.3.1 says that "the recursive
+ // response to a query will be... The answer to the query,
+ // possibly preface by one or more CNAME RRs that specify
+ // aliases encountered on the way to an answer."
+ //
+ // Therefore, we should be able to assume that we can ignore
+ // CNAMEs and that the A and AAAA records we requested are
+ // for the canonical name.
+
+ loop:
+ for {
+ h, err := racer.p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: racer.server,
+ }
+ }
+ if err != nil {
+ break
+ }
+ switch h.Type {
+ case dnsmessage.TypeA:
+ a, err := racer.p.AResource()
+ if err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: racer.server,
+ }
+ break loop
+ }
+ addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
+
+ case dnsmessage.TypeAAAA:
+ aaaa, err := racer.p.AAAAResource()
+ if err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: racer.server,
+ }
+ break loop
+ }
+ addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
+
+ default:
+ if err := racer.p.SkipAnswer(); err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: racer.server,
+ }
+ break loop
+ }
+ continue
+ }
+ if cname.Length == 0 && h.Name.Length != 0 {
+ cname = h.Name
+ }
}
}
if hitStrictError {
addrs = goLookupIPFiles(name)
}
if len(addrs) == 0 && lastErr != nil {
- return nil, "", lastErr
+ return nil, dnsmessage.Name{}, lastErr
}
}
return addrs, cname, nil
}
// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
-func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (cname string, err error) {
+func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (string, error) {
order := systemConf().hostLookupOrder(host)
- _, cname, err = r.goLookupIPCNAMEOrder(ctx, host, order)
- return
+ _, cname, err := r.goLookupIPCNAMEOrder(ctx, host, order)
+ return cname.String(), err
}
// goLookupPTR is the native Go implementation of LookupAddr.
if err != nil {
return nil, err
}
- _, rrs, err := r.lookup(ctx, arpa, dnsTypePTR)
+ p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR)
if err != nil {
return nil, err
}
- ptrs := make([]string, len(rrs))
- for i, rr := range rrs {
- ptrs[i] = rr.(*dnsRR_PTR).Ptr
+ var ptrs []string
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: addr,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypePTR {
+ continue
+ }
+ ptr, err := p.PTRResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: addr,
+ Server: server,
+ }
+ }
+ ptrs = append(ptrs, ptr.PTR.String())
+
}
return ptrs, nil
}
"sync"
"testing"
"time"
+
+ "golang_org/x/net/dns/dnsmessage"
)
var goResolver = Resolver{PreferGo: true}
// Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
-const TestAddr uint32 = 0xc0000201
+var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
// Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation.
var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+func mustNewName(name string) dnsmessage.Name {
+ nn, err := dnsmessage.NewName(name)
+ if err != nil {
+ panic(fmt.Sprint("creating name: ", err))
+ }
+ return nn
+}
+
+func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question {
+ return dnsmessage.Question{
+ Name: mustNewName(name),
+ Type: qtype,
+ Class: class,
+ }
+}
+
var dnsTransportFallbackTests = []struct {
- server string
- name string
- qtype uint16
- timeout int
- rcode int
+ server string
+ question dnsmessage.Question
+ timeout int
+ rcode dnsmessage.RCode
}{
// Querying "com." with qtype=255 usually makes an answer
// which requires more than 512 bytes.
- {"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
- {"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
+ {"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess},
+ {"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess},
}
func TestDNSTransportFallback(t *testing.T) {
fake := fakeDNSServer{
- rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- rcode: dnsRcodeSuccess,
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
},
- question: q.question,
+ Questions: q.Questions,
}
if n == "udp" {
- r.truncated = true
+ r.Header.Truncated = true
}
return r, nil
},
for _, tt := range dnsTransportFallbackTests {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
+ _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second)
if err != nil {
t.Error(err)
continue
}
- switch msg.rcode {
- case tt.rcode:
- default:
- t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
+ if h.RCode != tt.rcode {
+ t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode)
continue
}
}
// See RFC 6761 for further information about the reserved, pseudo
// domain names.
var specialDomainNameTests = []struct {
- name string
- qtype uint16
- rcode int
+ question dnsmessage.Question
+ rcode dnsmessage.RCode
}{
// Name resolution APIs and libraries should not recognize the
// followings as special.
- {"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
- {"test.", dnsTypeALL, dnsRcodeNameError},
- {"example.com.", dnsTypeALL, dnsRcodeSuccess},
+ {mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess},
// Name resolution APIs and libraries should recognize the
// followings as special and should not send any queries.
// Though, we test those names here for verifying negative
// answers at DNS query-response interaction level.
- {"localhost.", dnsTypeALL, dnsRcodeNameError},
- {"invalid.", dnsTypeALL, dnsRcodeNameError},
+ {mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
}
func TestSpecialDomainName(t *testing.T) {
- fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
- switch q.question[0].Name {
+ switch q.Questions[0].Name.String() {
case "example.com.":
- r.rcode = dnsRcodeSuccess
+ r.Header.RCode = dnsmessage.RCodeSuccess
default:
- r.rcode = dnsRcodeNameError
+ r.Header.RCode = dnsmessage.RCodeNameError
}
return r, nil
for _, tt := range specialDomainNameTests {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
+ _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second)
if err != nil {
t.Error(err)
continue
}
- switch msg.rcode {
- case tt.rcode, dnsRcodeServerFailure:
- default:
- t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
+ if h.RCode != tt.rcode {
+ t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode)
continue
}
}
}
}
-var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
- }
- if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA {
- r.answer = []dnsRR{
- &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ Questions: q.Questions,
+ }
+ if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
},
}
}
func TestGoLookupIPWithResolverConfig(t *testing.T) {
defer dnsWaitGroup.Wait()
-
- fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
break
default:
time.Sleep(10 * time.Millisecond)
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
}
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
- for _, question := range q.question {
- switch question.Qtype {
- case dnsTypeA:
- switch question.Name {
+ for _, question := range q.Questions {
+ switch question.Type {
+ case dnsmessage.TypeA:
+ switch question.Name.String() {
case "hostname.as112.net.":
break
case "ipv4.google.com.":
- r.answer = append(r.answer, &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
})
default:
}
- case dnsTypeAAAA:
- switch question.Name {
+ case dnsmessage.TypeAAAA:
+ switch question.Name.String() {
case "hostname.as112.net.":
break
case "ipv6.google.com.":
- r.answer = append(r.answer, &dnsRR_AAAA{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeAAAA,
- Class: dnsClassINET,
- Rdlength: 16,
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: TestAddr6,
},
- AAAA: TestAddr6,
})
}
}
func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
defer dnsWaitGroup.Wait()
- fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
return r, nil
}}
t.Fatal(err)
}
- fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
- switch q.question[0].Name {
+ switch q.Questions[0].Name.String() {
case fqdn + ".servfail.":
- r.rcode = dnsRcodeServerFailure
+ r.Header.RCode = dnsmessage.RCodeServerFailure
default:
- r.rcode = dnsRcodeNameError
+ r.Header.RCode = dnsmessage.RCodeNameError
}
return r, nil
t.Fatal(err)
}
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
if s == "192.0.2.2:53" {
- r.recursion_available = true
- if q.question[0].Qtype == dnsTypeA {
- r.answer = []dnsRR{
- &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ r.Header.RecursionAvailable = true
+ if q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
},
}
}
}
type fakeDNSServer struct {
- rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
+ rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
}
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
- return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
+ tcp := n == "tcp" || n == "tcp4" || n == "tcp6"
+ return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil
}
type fakeDNSConn struct {
Conn
+ tcp bool
server *fakeDNSServer
n string
s string
- q *dnsMsg
+ q dnsmessage.Message
t time.Time
+ buf []byte
}
func (f *fakeDNSConn) Close() error {
}
func (f *fakeDNSConn) Read(b []byte) (int, error) {
+ if len(f.buf) > 0 {
+ n := copy(b, f.buf)
+ f.buf = f.buf[n:]
+ return n, nil
+ }
+
resp, err := f.server.rh(f.n, f.s, f.q, f.t)
if err != nil {
return 0, err
}
- bb, ok := resp.Pack()
- if !ok {
- return 0, errors.New("cannot marshal DNS message")
+ bb := make([]byte, 2, 514)
+ bb, err = resp.AppendPack(bb)
+ if err != nil {
+ return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
}
+
+ if f.tcp {
+ l := len(bb) - 2
+ bb[0] = byte(l >> 8)
+ bb[1] = byte(l)
+ f.buf = bb
+ return f.Read(b)
+ }
+
+ bb = bb[2:]
if len(b) < len(bb) {
return 0, errors.New("read would fragment DNS message")
}
}
func (f *fakeDNSConn) Write(b []byte) (int, error) {
- f.q = new(dnsMsg)
- if !f.q.Unpack(b) {
- return 0, errors.New("cannot unmarshal DNS message")
+ if f.tcp && len(b) >= 2 {
+ b = b[2:]
+ }
+ if f.q.Unpack(b) != nil {
+ return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
}
return len(b), nil
}
return
}
- msg := &dnsMsg{}
- if !msg.Unpack(b[:n]) {
- t.Error("invalid DNS query")
+ var msg dnsmessage.Message
+ if msg.Unpack(b[:n]) != nil {
+ t.Error("invalid DNS query:", err)
return
}
s.Write([]byte("garbage DNS response packet"))
- msg.response = true
- msg.id++ // make invalid ID
- b, ok := msg.Pack()
- if !ok {
- t.Error("failed to pack DNS response")
+ msg.Header.Response = true
+ msg.Header.ID++ // make invalid ID
+
+ if b, err = msg.Pack(); err != nil {
+ t.Error("failed to pack DNS response:", err)
return
}
s.Write(b)
- msg.id-- // restore original ID
- msg.answer = []dnsRR{
- &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: "www.example.com.",
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ msg.Header.ID-- // restore original ID
+ msg.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: mustNewName("www.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
},
}
- b, ok = msg.Pack()
- if !ok {
- t.Error("failed to pack DNS response")
+ b, err = msg.Pack()
+ if err != nil {
+ t.Error("failed to pack DNS response:", err)
return
}
s.Write(b)
}()
- msg := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: 42,
},
- question: []dnsQuestion{
+ Questions: []dnsmessage.Question{
{
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
+ Name: mustNewName("www.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
},
},
}
- dc := &dnsPacketConn{c}
- resp, err := dc.dnsRoundTrip(msg)
+ b, err := msg.Pack()
if err != nil {
- t.Fatalf("dnsRoundTripUDP failed: %v", err)
+ t.Fatal("Pack failed:", err)
}
- if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
+ p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b)
+ if err != nil {
+ t.Fatalf("dnsPacketRoundTrip failed: %v", err)
+ }
+
+ p.SkipAllQuestions()
+ as, err := p.AllAnswers()
+ if err != nil {
+ t.Fatal("AllAnswers failed:", err)
+ }
+ if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr {
t.Errorf("got address %v, want %v", got, TestAddr)
}
}
var deadline0 time.Time
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q, deadline)
if deadline.IsZero() {
if s == "192.0.2.1:53" {
deadline0 = deadline
time.Sleep(10 * time.Millisecond)
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
}
if deadline.Equal(deadline0) {
}
var usedServers []string
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
usedServers = append(usedServers, s)
return mockTXTResponse(q), nil
}}
}
}
-func mockTXTResponse(q *dnsMsg) *dnsMsg {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- recursion_available: true,
+func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RecursionAvailable: true,
},
- question: q.question,
- answer: []dnsRR{
- &dnsRR_TXT{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeTXT,
- Class: dnsClassINET,
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeTXT,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"ok"},
},
- Txt: "ok",
},
},
}
cases := []struct {
desc string
- resolveWhich func(quest *dnsQuestion) resolveWhichEnum
+ resolveWhich func(quest dnsmessage.Question) resolveWhichEnum
wantStrictErr error
wantLaxErr error
wantIPs []string
}{
{
desc: "No errors",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
return resolveOK
},
wantIPs: []string{ip4, ip6},
},
{
desc: "searchX error fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchX {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX {
return resolveTimeout
}
return resolveOK
},
{
desc: "searchX IPv4-only timeout fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchX && quest.Qtype == dnsTypeA {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA {
return resolveTimeout
}
return resolveOK
},
{
desc: "searchX IPv6-only servfail fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchX && quest.Qtype == dnsTypeAAAA {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA {
return resolveServfail
}
return resolveOK
},
{
desc: "searchY error always fails",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchY {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY {
return resolveTimeout
}
return resolveOK
},
{
desc: "searchY IPv4-only socket error fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchY && quest.Qtype == dnsTypeA {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA {
return resolveOpError
}
return resolveOK
},
{
desc: "searchY IPv6-only timeout fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchY && quest.Qtype == dnsTypeAAAA {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA {
return resolveTimeout
}
return resolveOK
}
for i, tt := range cases {
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
- switch tt.resolveWhich(&q.question[0]) {
+ switch tt.resolveWhich(q.Questions[0]) {
case resolveOK:
// Handle below.
case resolveOpError:
- return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
+ return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
case resolveServfail:
- return &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- rcode: dnsRcodeServerFailure,
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeServerFailure,
},
- question: q.question,
+ Questions: q.Questions,
}, nil
case resolveTimeout:
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
default:
t.Fatal("Impossible resolveWhich")
}
- switch q.question[0].Name {
+ switch q.Questions[0].Name.String() {
case searchX, name + ".":
// Return NXDOMAIN to utilize the search list.
- return &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- rcode: dnsRcodeNameError,
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
},
- question: q.question,
+ Questions: q.Questions,
}, nil
case searchY:
// Return records below.
default:
- return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
}
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
- switch q.question[0].Qtype {
- case dnsTypeA:
- r.answer = []dnsRR{
- &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ switch q.Questions[0].Type {
+ case dnsmessage.TypeA:
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
},
}
- case dnsTypeAAAA:
- r.answer = []dnsRR{
- &dnsRR_AAAA{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeAAAA,
- Class: dnsClassINET,
- Rdlength: 16,
+ case dnsmessage.TypeAAAA:
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: TestAddr6,
},
- AAAA: TestAddr6,
},
}
default:
- return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype)
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type)
}
return r, nil
}}
const searchY = "test.y.golang.org."
const txt = "Hello World"
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
- switch q.question[0].Name {
+ switch q.Questions[0].Name.String() {
case searchX:
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
case searchY:
return mockTXTResponse(q), nil
default:
- return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
}
}}
for _, strict := range []bool{true, false} {
r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
- _, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT)
+ p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT)
var wantErr error
var wantRRs int
if strict {
if !reflect.DeepEqual(err, wantErr) {
t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
}
- if len(rrs) != wantRRs {
- t.Errorf("strict=%v: got %v; want %v", strict, len(rrs), wantRRs)
+ a, err := p.AllAnswers()
+ if err != nil {
+ a = nil
+ }
+ if len(a) != wantRRs {
+ t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs)
}
}
}
func TestDNSGoroutineRace(t *testing.T) {
defer dnsWaitGroup.Wait()
- fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
time.Sleep(10 * time.Microsecond)
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
}}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
t.Fatal("fake DNS lookup unexpectedly succeeded")
}
}
+
+// Issue 8434: verify that Temporary returns true on an error when rcode
+// is SERVFAIL
+func TestIssue8434(t *testing.T) {
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ RCode: dnsmessage.RCodeServerFailure,
+ },
+ }
+ b, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Pack failed:", err)
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b)
+ if err != nil {
+ t.Fatal("Start failed:", err)
+ }
+ if err := p.SkipAllQuestions(); err != nil {
+ t.Fatal("SkipAllQuestions failed:", err)
+ }
+
+ err = checkHeaders(&p, h, "golang.org", "foo:53")
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ if ne, ok := err.(Error); !ok {
+ t.Fatalf("err = %#v; wanted something supporting net.Error", err)
+ } else if !ne.Temporary() {
+ t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
+ }
+ if de, ok := err.(*DNSError); !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ } else if !de.IsTemporary {
+ t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
+ }
+}
+
+// Issue 12778: verify that NXDOMAIN without RA bit errors as
+// "no such host" and not "server misbehaving"
+func TestIssue12778(t *testing.T) {
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ RCode: dnsmessage.RCodeNameError,
+ RecursionAvailable: false,
+ },
+ }
+
+ b, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Pack failed:", err)
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b)
+ if err != nil {
+ t.Fatal("Start failed:", err)
+ }
+ if err := p.SkipAllQuestions(); err != nil {
+ t.Fatal("SkipAllQuestions failed:", err)
+ }
+
+ err = checkHeaders(&p, h, "golang.org", "foo:53")
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ de, ok := err.(*DNSError)
+ if !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ }
+ if de.Err != errNoSuchHost.Error() {
+ t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
+ }
+}
+++ /dev/null
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// DNS packet assembly. See RFC 1035.
-//
-// This is intended to support name resolution during Dial.
-// It doesn't have to be blazing fast.
-//
-// Each message structure has a Walk method that is used by
-// a generic pack/unpack routine. Thus, if in the future we need
-// to define new message structs, no new pack/unpack/printing code
-// needs to be written.
-//
-// The first half of this file defines the DNS message formats.
-// The second half implements the conversion to and from wire format.
-// A few of the structure elements have string tags to aid the
-// generic pack/unpack routines.
-//
-// TODO(rsc): There are enough names defined in this file that they're all
-// prefixed with dns. Perhaps put this in its own package later.
-
-package net
-
-// Packet formats
-
-// Wire constants.
-const (
- // valid dnsRR_Header.Rrtype and dnsQuestion.qtype
- dnsTypeA = 1
- dnsTypeNS = 2
- dnsTypeMD = 3
- dnsTypeMF = 4
- dnsTypeCNAME = 5
- dnsTypeSOA = 6
- dnsTypeMB = 7
- dnsTypeMG = 8
- dnsTypeMR = 9
- dnsTypeNULL = 10
- dnsTypeWKS = 11
- dnsTypePTR = 12
- dnsTypeHINFO = 13
- dnsTypeMINFO = 14
- dnsTypeMX = 15
- dnsTypeTXT = 16
- dnsTypeAAAA = 28
- dnsTypeSRV = 33
-
- // valid dnsQuestion.qtype only
- dnsTypeAXFR = 252
- dnsTypeMAILB = 253
- dnsTypeMAILA = 254
- dnsTypeALL = 255
-
- // valid dnsQuestion.qclass
- dnsClassINET = 1
- dnsClassCSNET = 2
- dnsClassCHAOS = 3
- dnsClassHESIOD = 4
- dnsClassANY = 255
-
- // dnsMsg.rcode
- dnsRcodeSuccess = 0
- dnsRcodeFormatError = 1
- dnsRcodeServerFailure = 2
- dnsRcodeNameError = 3
- dnsRcodeNotImplemented = 4
- dnsRcodeRefused = 5
-)
-
-// A dnsStruct describes how to iterate over its fields to emulate
-// reflective marshaling.
-type dnsStruct interface {
- // Walk iterates over fields of a structure and calls f
- // with a reference to that field, the name of the field
- // and a tag ("", "domain", "ipv4", "ipv6") specifying
- // particular encodings. Possible concrete types
- // for v are *uint16, *uint32, *string, or []byte, and
- // *int, *bool in the case of dnsMsgHdr.
- // Whenever f returns false, Walk must stop and return
- // false, and otherwise return true.
- Walk(f func(v interface{}, name, tag string) (ok bool)) (ok bool)
-}
-
-// The wire format for the DNS packet header.
-type dnsHeader struct {
- Id uint16
- Bits uint16
- Qdcount, Ancount, Nscount, Arcount uint16
-}
-
-func (h *dnsHeader) Walk(f func(v interface{}, name, tag string) bool) bool {
- return f(&h.Id, "Id", "") &&
- f(&h.Bits, "Bits", "") &&
- f(&h.Qdcount, "Qdcount", "") &&
- f(&h.Ancount, "Ancount", "") &&
- f(&h.Nscount, "Nscount", "") &&
- f(&h.Arcount, "Arcount", "")
-}
-
-const (
- // dnsHeader.Bits
- _QR = 1 << 15 // query/response (response=1)
- _AA = 1 << 10 // authoritative
- _TC = 1 << 9 // truncated
- _RD = 1 << 8 // recursion desired
- _RA = 1 << 7 // recursion available
-)
-
-// DNS queries.
-type dnsQuestion struct {
- Name string
- Qtype uint16
- Qclass uint16
-}
-
-func (q *dnsQuestion) Walk(f func(v interface{}, name, tag string) bool) bool {
- return f(&q.Name, "Name", "domain") &&
- f(&q.Qtype, "Qtype", "") &&
- f(&q.Qclass, "Qclass", "")
-}
-
-// DNS responses (resource records).
-// There are many types of messages,
-// but they all share the same header.
-type dnsRR_Header struct {
- Name string
- Rrtype uint16
- Class uint16
- Ttl uint32
- Rdlength uint16 // length of data after header
-}
-
-func (h *dnsRR_Header) Header() *dnsRR_Header {
- return h
-}
-
-func (h *dnsRR_Header) Walk(f func(v interface{}, name, tag string) bool) bool {
- return f(&h.Name, "Name", "domain") &&
- f(&h.Rrtype, "Rrtype", "") &&
- f(&h.Class, "Class", "") &&
- f(&h.Ttl, "Ttl", "") &&
- f(&h.Rdlength, "Rdlength", "")
-}
-
-type dnsRR interface {
- dnsStruct
- Header() *dnsRR_Header
-}
-
-// Specific DNS RR formats for each query type.
-
-type dnsRR_CNAME struct {
- Hdr dnsRR_Header
- Cname string
-}
-
-func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_CNAME) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.Cname, "Cname", "domain")
-}
-
-type dnsRR_MX struct {
- Hdr dnsRR_Header
- Pref uint16
- Mx string
-}
-
-func (rr *dnsRR_MX) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_MX) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.Pref, "Pref", "") && f(&rr.Mx, "Mx", "domain")
-}
-
-type dnsRR_NS struct {
- Hdr dnsRR_Header
- Ns string
-}
-
-func (rr *dnsRR_NS) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_NS) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.Ns, "Ns", "domain")
-}
-
-type dnsRR_PTR struct {
- Hdr dnsRR_Header
- Ptr string
-}
-
-func (rr *dnsRR_PTR) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_PTR) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.Ptr, "Ptr", "domain")
-}
-
-type dnsRR_SOA struct {
- Hdr dnsRR_Header
- Ns string
- Mbox string
- Serial uint32
- Refresh uint32
- Retry uint32
- Expire uint32
- Minttl uint32
-}
-
-func (rr *dnsRR_SOA) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_SOA) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) &&
- f(&rr.Ns, "Ns", "domain") &&
- f(&rr.Mbox, "Mbox", "domain") &&
- f(&rr.Serial, "Serial", "") &&
- f(&rr.Refresh, "Refresh", "") &&
- f(&rr.Retry, "Retry", "") &&
- f(&rr.Expire, "Expire", "") &&
- f(&rr.Minttl, "Minttl", "")
-}
-
-type dnsRR_TXT struct {
- Hdr dnsRR_Header
- Txt string // not domain name
-}
-
-func (rr *dnsRR_TXT) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_TXT) Walk(f func(v interface{}, name, tag string) bool) bool {
- if !rr.Hdr.Walk(f) {
- return false
- }
- var n uint16 = 0
- for n < rr.Hdr.Rdlength {
- var txt string
- if !f(&txt, "Txt", "") {
- return false
- }
- // more bytes than rr.Hdr.Rdlength said there would be
- if rr.Hdr.Rdlength-n < uint16(len(txt))+1 {
- return false
- }
- n += uint16(len(txt)) + 1
- rr.Txt += txt
- }
- return true
-}
-
-type dnsRR_SRV struct {
- Hdr dnsRR_Header
- Priority uint16
- Weight uint16
- Port uint16
- Target string
-}
-
-func (rr *dnsRR_SRV) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_SRV) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) &&
- f(&rr.Priority, "Priority", "") &&
- f(&rr.Weight, "Weight", "") &&
- f(&rr.Port, "Port", "") &&
- f(&rr.Target, "Target", "domain")
-}
-
-type dnsRR_A struct {
- Hdr dnsRR_Header
- A uint32
-}
-
-func (rr *dnsRR_A) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_A) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.A, "A", "ipv4")
-}
-
-type dnsRR_AAAA struct {
- Hdr dnsRR_Header
- AAAA [16]byte
-}
-
-func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_AAAA) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(rr.AAAA[:], "AAAA", "ipv6")
-}
-
-// Packing and unpacking.
-//
-// All the packers and unpackers take a (msg []byte, off int)
-// and return (off1 int, ok bool). If they return ok==false, they
-// also return off1==len(msg), so that the next unpacker will
-// also fail. This lets us avoid checks of ok until the end of a
-// packing sequence.
-
-// Map of constructors for each RR wire type.
-var rr_mk = map[int]func() dnsRR{
- dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) },
- dnsTypeMX: func() dnsRR { return new(dnsRR_MX) },
- dnsTypeNS: func() dnsRR { return new(dnsRR_NS) },
- dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) },
- dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) },
- dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) },
- dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) },
- dnsTypeA: func() dnsRR { return new(dnsRR_A) },
- dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) },
-}
-
-// Pack a domain name s into msg[off:].
-// Domain names are a sequence of counted strings
-// split at the dots. They end with a zero-length string.
-func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
- // Add trailing dot to canonicalize name.
- if n := len(s); n == 0 || s[n-1] != '.' {
- s += "."
- }
-
- // Allow root domain.
- if s == "." {
- msg[off] = 0
- off++
- return off, true
- }
-
- // Each dot ends a segment of the name.
- // We trade each dot byte for a length byte.
- // There is also a trailing zero.
- // Check that we have all the space we need.
- tot := len(s) + 1
- if off+tot > len(msg) {
- return len(msg), false
- }
-
- // Emit sequence of counted strings, chopping at dots.
- begin := 0
- for i := 0; i < len(s); i++ {
- if s[i] == '.' {
- if i-begin >= 1<<6 { // top two bits of length must be clear
- return len(msg), false
- }
- if i-begin == 0 {
- return len(msg), false
- }
-
- msg[off] = byte(i - begin)
- off++
-
- for j := begin; j < i; j++ {
- msg[off] = s[j]
- off++
- }
- begin = i + 1
- }
- }
- msg[off] = 0
- off++
- return off, true
-}
-
-// Unpack a domain name.
-// In addition to the simple sequences of counted strings above,
-// domain names are allowed to refer to strings elsewhere in the
-// packet, to avoid repeating common suffixes when returning
-// many entries in a single domain. The pointers are marked
-// by a length byte with the top two bits set. Ignoring those
-// two bits, that byte and the next give a 14 bit offset from msg[0]
-// where we should pick up the trail.
-// Note that if we jump elsewhere in the packet,
-// we return off1 == the offset after the first pointer we found,
-// which is where the next record will start.
-// In theory, the pointers are only allowed to jump backward.
-// We let them jump anywhere and stop jumping after a while.
-func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
- s = ""
- ptr := 0 // number of pointers followed
-Loop:
- for {
- if off >= len(msg) {
- return "", len(msg), false
- }
- c := int(msg[off])
- off++
- switch c & 0xC0 {
- case 0x00:
- if c == 0x00 {
- // end of name
- break Loop
- }
- // literal string
- if off+c > len(msg) {
- return "", len(msg), false
- }
- s += string(msg[off:off+c]) + "."
- off += c
- case 0xC0:
- // pointer to somewhere else in msg.
- // remember location after first ptr,
- // since that's how many bytes we consumed.
- // also, don't follow too many pointers --
- // maybe there's a loop.
- if off >= len(msg) {
- return "", len(msg), false
- }
- c1 := msg[off]
- off++
- if ptr == 0 {
- off1 = off
- }
- if ptr++; ptr > 10 {
- return "", len(msg), false
- }
- off = (c^0xC0)<<8 | int(c1)
- default:
- // 0x80 and 0x40 are reserved
- return "", len(msg), false
- }
- }
- if len(s) == 0 {
- s = "."
- }
- if ptr == 0 {
- off1 = off
- }
- return s, off1, true
-}
-
-// packStruct packs a structure into msg at specified offset off, and
-// returns off1 such that msg[off:off1] is the encoded data.
-func packStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
- ok = any.Walk(func(field interface{}, name, tag string) bool {
- switch fv := field.(type) {
- default:
- println("net: dns: unknown packing type")
- return false
- case *uint16:
- i := *fv
- if off+2 > len(msg) {
- return false
- }
- msg[off] = byte(i >> 8)
- msg[off+1] = byte(i)
- off += 2
- case *uint32:
- i := *fv
- msg[off] = byte(i >> 24)
- msg[off+1] = byte(i >> 16)
- msg[off+2] = byte(i >> 8)
- msg[off+3] = byte(i)
- off += 4
- case []byte:
- n := len(fv)
- if off+n > len(msg) {
- return false
- }
- copy(msg[off:off+n], fv)
- off += n
- case *string:
- s := *fv
- switch tag {
- default:
- println("net: dns: unknown string tag", tag)
- return false
- case "domain":
- off, ok = packDomainName(s, msg, off)
- if !ok {
- return false
- }
- case "":
- // Counted string: 1 byte length.
- if len(s) > 255 || off+1+len(s) > len(msg) {
- return false
- }
- msg[off] = byte(len(s))
- off++
- off += copy(msg[off:], s)
- }
- }
- return true
- })
- if !ok {
- return len(msg), false
- }
- return off, true
-}
-
-// unpackStruct decodes msg[off:]Â into the given structure, and
-// returns off1 such that msg[off:off1] is the encoded data.
-func unpackStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
- ok = any.Walk(func(field interface{}, name, tag string) bool {
- switch fv := field.(type) {
- default:
- println("net: dns: unknown packing type")
- return false
- case *uint16:
- if off+2 > len(msg) {
- return false
- }
- *fv = uint16(msg[off])<<8 | uint16(msg[off+1])
- off += 2
- case *uint32:
- if off+4 > len(msg) {
- return false
- }
- *fv = uint32(msg[off])<<24 | uint32(msg[off+1])<<16 |
- uint32(msg[off+2])<<8 | uint32(msg[off+3])
- off += 4
- case []byte:
- n := len(fv)
- if off+n > len(msg) {
- return false
- }
- copy(fv, msg[off:off+n])
- off += n
- case *string:
- var s string
- switch tag {
- default:
- println("net: dns: unknown string tag", tag)
- return false
- case "domain":
- s, off, ok = unpackDomainName(msg, off)
- if !ok {
- return false
- }
- case "":
- if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
- return false
- }
- n := int(msg[off])
- off++
- b := make([]byte, n)
- for i := 0; i < n; i++ {
- b[i] = msg[off+i]
- }
- off += n
- s = string(b)
- }
- *fv = s
- }
- return true
- })
- if !ok {
- return len(msg), false
- }
- return off, true
-}
-
-// Generic struct printer. Prints fields with tag "ipv4" or "ipv6"
-// as IP addresses.
-func printStruct(any dnsStruct) string {
- s := "{"
- i := 0
- any.Walk(func(val interface{}, name, tag string) bool {
- i++
- if i > 1 {
- s += ", "
- }
- s += name + "="
- switch tag {
- case "ipv4":
- i := *val.(*uint32)
- s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
- case "ipv6":
- i := val.([]byte)
- s += IP(i).String()
- default:
- var i int64
- switch v := val.(type) {
- default:
- // can't really happen.
- s += "<unknown type>"
- return true
- case *string:
- s += *v
- return true
- case []byte:
- s += string(v)
- return true
- case *bool:
- if *v {
- s += "true"
- } else {
- s += "false"
- }
- return true
- case *int:
- i = int64(*v)
- case *uint:
- i = int64(*v)
- case *uint8:
- i = int64(*v)
- case *uint16:
- i = int64(*v)
- case *uint32:
- i = int64(*v)
- case *uint64:
- i = int64(*v)
- case *uintptr:
- i = int64(*v)
- }
- s += itoa(int(i))
- }
- return true
- })
- s += "}"
- return s
-}
-
-// Resource record packer.
-func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
- var off1 int
- // pack twice, once to find end of header
- // and again to find end of packet.
- // a bit inefficient but this doesn't need to be fast.
- // off1 is end of header
- // off2 is end of rr
- off1, ok = packStruct(rr.Header(), msg, off)
- if !ok {
- return len(msg), false
- }
- off2, ok = packStruct(rr, msg, off)
- if !ok {
- return len(msg), false
- }
- // pack a third time; redo header with correct data length
- rr.Header().Rdlength = uint16(off2 - off1)
- packStruct(rr.Header(), msg, off)
- return off2, true
-}
-
-// Resource record unpacker.
-func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) {
- // unpack just the header, to find the rr type and length
- var h dnsRR_Header
- off0 := off
- if off, ok = unpackStruct(&h, msg, off); !ok {
- return nil, len(msg), false
- }
- end := off + int(h.Rdlength)
-
- // make an rr of that type and re-unpack.
- // again inefficient but doesn't need to be fast.
- mk, known := rr_mk[int(h.Rrtype)]
- if !known {
- return &h, end, true
- }
- rr = mk()
- off, ok = unpackStruct(rr, msg, off0)
- if off != end {
- return &h, end, true
- }
- return rr, off, ok
-}
-
-// Usable representation of a DNS packet.
-
-// A manually-unpacked version of (id, bits).
-// This is in its own struct for easy printing.
-type dnsMsgHdr struct {
- id uint16
- response bool
- opcode int
- authoritative bool
- truncated bool
- recursion_desired bool
- recursion_available bool
- rcode int
-}
-
-func (h *dnsMsgHdr) Walk(f func(v interface{}, name, tag string) bool) bool {
- return f(&h.id, "id", "") &&
- f(&h.response, "response", "") &&
- f(&h.opcode, "opcode", "") &&
- f(&h.authoritative, "authoritative", "") &&
- f(&h.truncated, "truncated", "") &&
- f(&h.recursion_desired, "recursion_desired", "") &&
- f(&h.recursion_available, "recursion_available", "") &&
- f(&h.rcode, "rcode", "")
-}
-
-type dnsMsg struct {
- dnsMsgHdr
- question []dnsQuestion
- answer []dnsRR
- ns []dnsRR
- extra []dnsRR
-}
-
-func (dns *dnsMsg) Pack() (msg []byte, ok bool) {
- var dh dnsHeader
-
- // Convert convenient dnsMsg into wire-like dnsHeader.
- dh.Id = dns.id
- dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode)
- if dns.recursion_available {
- dh.Bits |= _RA
- }
- if dns.recursion_desired {
- dh.Bits |= _RD
- }
- if dns.truncated {
- dh.Bits |= _TC
- }
- if dns.authoritative {
- dh.Bits |= _AA
- }
- if dns.response {
- dh.Bits |= _QR
- }
-
- // Prepare variable sized arrays.
- question := dns.question
- answer := dns.answer
- ns := dns.ns
- extra := dns.extra
-
- dh.Qdcount = uint16(len(question))
- dh.Ancount = uint16(len(answer))
- dh.Nscount = uint16(len(ns))
- dh.Arcount = uint16(len(extra))
-
- // Could work harder to calculate message size,
- // but this is far more than we need and not
- // big enough to hurt the allocator.
- msg = make([]byte, 2000)
-
- // Pack it in: header and then the pieces.
- off := 0
- off, ok = packStruct(&dh, msg, off)
- if !ok {
- return nil, false
- }
- for i := 0; i < len(question); i++ {
- off, ok = packStruct(&question[i], msg, off)
- if !ok {
- return nil, false
- }
- }
- for i := 0; i < len(answer); i++ {
- off, ok = packRR(answer[i], msg, off)
- if !ok {
- return nil, false
- }
- }
- for i := 0; i < len(ns); i++ {
- off, ok = packRR(ns[i], msg, off)
- if !ok {
- return nil, false
- }
- }
- for i := 0; i < len(extra); i++ {
- off, ok = packRR(extra[i], msg, off)
- if !ok {
- return nil, false
- }
- }
- return msg[0:off], true
-}
-
-func (dns *dnsMsg) Unpack(msg []byte) bool {
- // Header.
- var dh dnsHeader
- off := 0
- var ok bool
- if off, ok = unpackStruct(&dh, msg, off); !ok {
- return false
- }
- dns.id = dh.Id
- dns.response = (dh.Bits & _QR) != 0
- dns.opcode = int(dh.Bits>>11) & 0xF
- dns.authoritative = (dh.Bits & _AA) != 0
- dns.truncated = (dh.Bits & _TC) != 0
- dns.recursion_desired = (dh.Bits & _RD) != 0
- dns.recursion_available = (dh.Bits & _RA) != 0
- dns.rcode = int(dh.Bits & 0xF)
-
- // Arrays.
- dns.question = make([]dnsQuestion, dh.Qdcount)
- dns.answer = make([]dnsRR, 0, dh.Ancount)
- dns.ns = make([]dnsRR, 0, dh.Nscount)
- dns.extra = make([]dnsRR, 0, dh.Arcount)
-
- var rec dnsRR
-
- for i := 0; i < len(dns.question); i++ {
- off, ok = unpackStruct(&dns.question[i], msg, off)
- if !ok {
- return false
- }
- }
- for i := 0; i < int(dh.Ancount); i++ {
- rec, off, ok = unpackRR(msg, off)
- if !ok {
- return false
- }
- dns.answer = append(dns.answer, rec)
- }
- for i := 0; i < int(dh.Nscount); i++ {
- rec, off, ok = unpackRR(msg, off)
- if !ok {
- return false
- }
- dns.ns = append(dns.ns, rec)
- }
- for i := 0; i < int(dh.Arcount); i++ {
- rec, off, ok = unpackRR(msg, off)
- if !ok {
- return false
- }
- dns.extra = append(dns.extra, rec)
- }
- // if off != len(msg) {
- // println("extra bytes in dns packet", off, "<", len(msg));
- // }
- return true
-}
-
-func (dns *dnsMsg) String() string {
- s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n"
- if len(dns.question) > 0 {
- s += "-- Questions\n"
- for i := 0; i < len(dns.question); i++ {
- s += printStruct(&dns.question[i]) + "\n"
- }
- }
- if len(dns.answer) > 0 {
- s += "-- Answers\n"
- for i := 0; i < len(dns.answer); i++ {
- s += printStruct(dns.answer[i]) + "\n"
- }
- }
- if len(dns.ns) > 0 {
- s += "-- Name servers\n"
- for i := 0; i < len(dns.ns); i++ {
- s += printStruct(dns.ns[i]) + "\n"
- }
- }
- if len(dns.extra) > 0 {
- s += "-- Extra\n"
- for i := 0; i < len(dns.extra); i++ {
- s += printStruct(dns.extra[i]) + "\n"
- }
- }
- return s
-}
-
-// IsResponseTo reports whether m is an acceptable response to query.
-func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool {
- if !m.response {
- return false
- }
- if m.id != query.id {
- return false
- }
- if len(m.question) != len(query.question) {
- return false
- }
- for i, q := range m.question {
- q2 := query.question[i]
- if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass {
- return false
- }
- }
- return true
-}
+++ /dev/null
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package net
-
-import (
- "encoding/hex"
- "reflect"
- "testing"
-)
-
-func TestStructPackUnpack(t *testing.T) {
- want := dnsQuestion{
- Name: ".",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- }
- buf := make([]byte, 50)
- n, ok := packStruct(&want, buf, 0)
- if !ok {
- t.Fatal("packing failed")
- }
- buf = buf[:n]
- got := dnsQuestion{}
- n, ok = unpackStruct(&got, buf, 0)
- if !ok {
- t.Fatal("unpacking failed")
- }
- if n != len(buf) {
- t.Errorf("unpacked different amount than packed: got n = %d, want = %d", n, len(buf))
- }
- if !reflect.DeepEqual(got, want) {
- t.Errorf("got = %+v, want = %+v", got, want)
- }
-}
-
-func TestDomainNamePackUnpack(t *testing.T) {
- tests := []struct {
- in string
- want string
- ok bool
- }{
- {"", ".", true},
- {".", ".", true},
- {"google..com", "", false},
- {"google.com", "google.com.", true},
- {"google..com.", "", false},
- {"google.com.", "google.com.", true},
- {".google.com.", "", false},
- {"www..google.com.", "", false},
- {"www.google.com.", "www.google.com.", true},
- }
-
- for _, test := range tests {
- buf := make([]byte, 30)
- n, ok := packDomainName(test.in, buf, 0)
- if ok != test.ok {
- t.Errorf("packing of %s: got ok = %t, want = %t", test.in, ok, test.ok)
- continue
- }
- if !test.ok {
- continue
- }
- buf = buf[:n]
- got, n, ok := unpackDomainName(buf, 0)
- if !ok {
- t.Errorf("unpacking for %s failed", test.in)
- continue
- }
- if n != len(buf) {
- t.Errorf(
- "unpacked different amount than packed for %s: got n = %d, want = %d",
- test.in,
- n,
- len(buf),
- )
- }
- if got != test.want {
- t.Errorf("unpacking packing of %s: got = %s, want = %s", test.in, got, test.want)
- }
- }
-}
-
-func TestDNSPackUnpack(t *testing.T) {
- want := dnsMsg{
- question: []dnsQuestion{{
- Name: ".",
- Qtype: dnsTypeAAAA,
- Qclass: dnsClassINET,
- }},
- answer: []dnsRR{},
- ns: []dnsRR{},
- extra: []dnsRR{},
- }
- b, ok := want.Pack()
- if !ok {
- t.Fatal("packing failed")
- }
- var got dnsMsg
- ok = got.Unpack(b)
- if !ok {
- t.Fatal("unpacking failed")
- }
- if !reflect.DeepEqual(got, want) {
- t.Errorf("got = %+v, want = %+v", got, want)
- }
-}
-
-func TestDNSParseSRVReply(t *testing.T) {
- data, err := hex.DecodeString(dnsSRVReply)
- if err != nil {
- t.Fatal(err)
- }
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- if !ok {
- t.Fatal("unpacking packet failed")
- }
- _ = msg.String() // exercise this code path
- if g, e := len(msg.answer), 5; g != e {
- t.Errorf("len(msg.answer) = %d; want %d", g, e)
- }
- for idx, rr := range msg.answer {
- if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e {
- t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e)
- }
- if _, ok := rr.(*dnsRR_SRV); !ok {
- t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr)
- }
- }
- for _, name := range [...]string{
- "_xmpp-server._tcp.google.com.",
- "_XMPP-Server._TCP.Google.COM.",
- "_XMPP-SERVER._TCP.GOOGLE.COM.",
- } {
- _, addrs, err := answer(name, "foo:53", msg, uint16(dnsTypeSRV))
- if err != nil {
- t.Error(err)
- }
- if g, e := len(addrs), 5; g != e {
- t.Errorf("len(addrs) = %d; want %d", g, e)
- t.Logf("addrs = %#v", addrs)
- }
- }
- // repack and unpack.
- data2, ok := msg.Pack()
- msg2 := new(dnsMsg)
- msg2.Unpack(data2)
- switch {
- case !ok:
- t.Error("failed to repack message")
- case !reflect.DeepEqual(msg, msg2):
- t.Error("repacked message differs from original")
- }
-}
-
-func TestDNSParseCorruptSRVReply(t *testing.T) {
- data, err := hex.DecodeString(dnsSRVCorruptReply)
- if err != nil {
- t.Fatal(err)
- }
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- if !ok {
- t.Fatal("unpacking packet failed")
- }
- _ = msg.String() // exercise this code path
- if g, e := len(msg.answer), 5; g != e {
- t.Errorf("len(msg.answer) = %d; want %d", g, e)
- }
- for idx, rr := range msg.answer {
- if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e {
- t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e)
- }
- if idx == 4 {
- if _, ok := rr.(*dnsRR_Header); !ok {
- t.Errorf("answer[%d] = %T; want *dnsRR_Header", idx, rr)
- }
- } else {
- if _, ok := rr.(*dnsRR_SRV); !ok {
- t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr)
- }
- }
- }
- _, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV))
- if err != nil {
- t.Fatalf("answer: %v", err)
- }
- if g, e := len(addrs), 4; g != e {
- t.Errorf("len(addrs) = %d; want %d", g, e)
- t.Logf("addrs = %#v", addrs)
- }
-}
-
-func TestDNSParseTXTReply(t *testing.T) {
- expectedTxt1 := "v=spf1 redirect=_spf.google.com"
- expectedTxt2 := "v=spf1 ip4:69.63.179.25 ip4:69.63.178.128/25 ip4:69.63.184.0/25 " +
- "ip4:66.220.144.128/25 ip4:66.220.155.0/24 " +
- "ip4:69.171.232.0/25 ip4:66.220.157.0/25 " +
- "ip4:69.171.244.0/24 mx -all"
-
- replies := []string{dnsTXTReply1, dnsTXTReply2}
- expectedTxts := []string{expectedTxt1, expectedTxt2}
-
- for i := range replies {
- data, err := hex.DecodeString(replies[i])
- if err != nil {
- t.Fatal(err)
- }
-
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- if !ok {
- t.Errorf("test %d: unpacking packet failed", i)
- continue
- }
-
- if len(msg.answer) != 1 {
- t.Errorf("test %d: len(rr.answer) = %d; want 1", i, len(msg.answer))
- continue
- }
-
- rr := msg.answer[0]
- rrTXT, ok := rr.(*dnsRR_TXT)
- if !ok {
- t.Errorf("test %d: answer[0] = %T; want *dnsRR_TXT", i, rr)
- continue
- }
-
- if rrTXT.Txt != expectedTxts[i] {
- t.Errorf("test %d: Txt = %s; want %s", i, rrTXT.Txt, expectedTxts[i])
- }
- }
-}
-
-func TestDNSParseTXTCorruptDataLengthReply(t *testing.T) {
- replies := []string{dnsTXTCorruptDataLengthReply1, dnsTXTCorruptDataLengthReply2}
-
- for i := range replies {
- data, err := hex.DecodeString(replies[i])
- if err != nil {
- t.Fatal(err)
- }
-
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- if ok {
- t.Errorf("test %d: expected to fail on unpacking corrupt packet", i)
- }
- }
-}
-
-func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) {
- replies := []string{dnsTXTCorruptTXTLengthReply1, dnsTXTCorruptTXTLengthReply2}
-
- for i := range replies {
- data, err := hex.DecodeString(replies[i])
- if err != nil {
- t.Fatal(err)
- }
-
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- // Unpacking should succeed, but we should just get the header.
- if !ok {
- t.Errorf("test %d: unpacking packet failed", i)
- continue
- }
-
- if len(msg.answer) != 1 {
- t.Errorf("test %d: len(rr.answer) = %d; want 1", i, len(msg.answer))
- continue
- }
-
- rr := msg.answer[0]
- if _, justHeader := rr.(*dnsRR_Header); !justHeader {
- t.Errorf("test %d: rr = %T; expected *dnsRR_Header", i, rr)
- }
- }
-}
-
-func TestIsResponseTo(t *testing.T) {
- // Sample DNS query.
- query := dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- },
- },
- }
-
- resp := query
- resp.response = true
- if !resp.IsResponseTo(&query) {
- t.Error("got false, want true")
- }
-
- badResponses := []dnsMsg{
- // Different ID.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 43,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- },
- },
- },
-
- // Different query name.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.google.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- },
- },
- },
-
- // Different query type.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeAAAA,
- Qclass: dnsClassINET,
- },
- },
- },
-
- // Different query class.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassCSNET,
- },
- },
- },
-
- // No questions.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- },
-
- // Extra questions.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- },
- {
- Name: "www.golang.org.",
- Qtype: dnsTypeAAAA,
- Qclass: dnsClassINET,
- },
- },
- },
- }
-
- for i := range badResponses {
- if badResponses[i].IsResponseTo(&query) {
- t.Errorf("%v: got true, want false", i)
- }
- }
-}
-
-// Valid DNS SRV reply
-const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
- "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
- "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" +
- "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" +
- "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" +
- "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" +
- "72016c06676f6f676c6503636f6d00c00c002100010000012c00210014000014950c78" +
- "6d70702d73657276657231016c06676f6f676c6503636f6d00"
-
-// Corrupt DNS SRV reply, with its final RR having a bogus length
-// (perhaps it was truncated, or it's malicious) The mutation is the
-// capital "FF" below, instead of the proper "21".
-const dnsSRVCorruptReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
- "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
- "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" +
- "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" +
- "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" +
- "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" +
- "72016c06676f6f676c6503636f6d00c00c002100010000012c00FF0014000014950c78" +
- "6d70702d73657276657231016c06676f6f676c6503636f6d00"
-
-// TXT reply with one <character-string>
-const dnsTXTReply1 = "b3458180000100010004000505676d61696c03636f6d0000100001c00c001000010000012c00" +
- "201f763d737066312072656469726563743d5f7370662e676f6f676c652e636f6dc00" +
- "c0002000100025d4c000d036e733406676f6f676c65c012c00c0002000100025d4c00" +
- "06036e7331c057c00c0002000100025d4c0006036e7333c057c00c0002000100025d4" +
- "c0006036e7332c057c06c00010001000248b50004d8ef200ac09000010001000248b5" +
- "0004d8ef220ac07e00010001000248b50004d8ef240ac05300010001000248b50004d" +
- "8ef260a0000291000000000000000"
-
-// TXT reply with more than one <character-string>.
-// See https://tools.ietf.org/html/rfc1035#section-3.3.14
-const dnsTXTReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
-
-// DataLength field should be sum of all TXT fields. In this case it's less.
-const dnsTXTCorruptDataLengthReply1 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1000967f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
-
-// Same as above but DataLength is more than sum of TXT fields.
-const dnsTXTCorruptDataLengthReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1001227f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
-
-// TXT Length field is less than actual length.
-const dnsTXTCorruptTXTLengthReply1 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520691470343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
-
-// TXT Length field is more than actual length.
-const dnsTXTCorruptTXTLengthReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520693370343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
import (
"context"
"sync"
+
+ "golang_org/x/net/dns/dnsmessage"
)
var onceReadProtocols sync.Once
return lookupProtocolMap(name)
}
-func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, error) {
+func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
if err != nil {
return nil, mapErr(err)
}
- if _, ok := c.(PacketConn); ok {
- return &dnsPacketConn{c}, nil
- }
- return &dnsStreamConn{c}, nil
+ return c, nil
}
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
// cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS
}
- addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order)
- return
+ ips, _, err := r.goLookupIPCNAMEOrder(ctx, host, order)
+ return ips, err
}
func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
} else {
target = "_" + service + "._" + proto + "." + name
}
- cname, rrs, err := r.lookup(ctx, target, dnsTypeSRV)
+ p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV)
if err != nil {
return "", nil, err
}
- srvs := make([]*SRV, len(rrs))
- for i, rr := range rrs {
- rr := rr.(*dnsRR_SRV)
- srvs[i] = &SRV{Target: rr.Target, Port: rr.Port, Priority: rr.Priority, Weight: rr.Weight}
+ var srvs []*SRV
+ var cname dnsmessage.Name
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeSRV {
+ if err := p.SkipAnswer(); err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ if cname.Length == 0 && h.Name.Length != 0 {
+ cname = h.Name
+ }
+ srv, err := p.SRVResource()
+ if err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
}
byPriorityWeight(srvs).sort()
- return cname, srvs, nil
+ return cname.String(), srvs, nil
}
func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
- _, rrs, err := r.lookup(ctx, name, dnsTypeMX)
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX)
if err != nil {
return nil, err
}
- mxs := make([]*MX, len(rrs))
- for i, rr := range rrs {
- rr := rr.(*dnsRR_MX)
- mxs[i] = &MX{Host: rr.Mx, Pref: rr.Pref}
+ var mxs []*MX
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeMX {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ mx, err := p.MXResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
+
}
byPref(mxs).sort()
return mxs, nil
}
func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
- _, rrs, err := r.lookup(ctx, name, dnsTypeNS)
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS)
if err != nil {
return nil, err
}
- nss := make([]*NS, len(rrs))
- for i, rr := range rrs {
- nss[i] = &NS{Host: rr.(*dnsRR_NS).Ns}
+ var nss []*NS
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeNS {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ ns, err := p.NSResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ nss = append(nss, &NS{Host: ns.NS.String()})
}
return nss, nil
}
func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
- _, rrs, err := r.lookup(ctx, name, dnsTypeTXT)
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT)
if err != nil {
return nil, err
}
- txts := make([]string, len(rrs))
- for i, rr := range rrs {
- txts[i] = rr.(*dnsRR_TXT).Txt
+ var txts []string
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeTXT {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ txt, err := p.TXTResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if len(txts) == 0 {
+ txts = txt.TXT
+ } else {
+ txts = append(txts, txt.TXT...)
+ }
}
return txts, nil
}
--- /dev/null
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dnsmessage_test
+
+import (
+ "fmt"
+ "net"
+ "strings"
+
+ "golang_org/x/net/dns/dnsmessage"
+)
+
+func mustNewName(name string) dnsmessage.Name {
+ n, err := dnsmessage.NewName(name)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func ExampleParser() {
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{Response: true, Authoritative: true},
+ Questions: []dnsmessage.Question{
+ {
+ Name: mustNewName("foo.bar.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ {
+ Name: mustNewName("bar.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ },
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: mustNewName("foo.bar.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
+ },
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: mustNewName("bar.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}},
+ },
+ },
+ }
+
+ buf, err := msg.Pack()
+ if err != nil {
+ panic(err)
+ }
+
+ wantName := "bar.example.com."
+
+ var p dnsmessage.Parser
+ if _, err := p.Start(buf); err != nil {
+ panic(err)
+ }
+
+ for {
+ q, err := p.Question()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ panic(err)
+ }
+
+ if q.Name.String() != wantName {
+ continue
+ }
+
+ fmt.Println("Found question for name", wantName)
+ if err := p.SkipAllQuestions(); err != nil {
+ panic(err)
+ }
+ break
+ }
+
+ var gotIPs []net.IP
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ panic(err)
+ }
+
+ if (h.Type != dnsmessage.TypeA && h.Type != dnsmessage.TypeAAAA) || h.Class != dnsmessage.ClassINET {
+ continue
+ }
+
+ if !strings.EqualFold(h.Name.String(), wantName) {
+ if err := p.SkipAnswer(); err != nil {
+ panic(err)
+ }
+ continue
+ }
+
+ switch h.Type {
+ case dnsmessage.TypeA:
+ r, err := p.AResource()
+ if err != nil {
+ panic(err)
+ }
+ gotIPs = append(gotIPs, r.A[:])
+ case dnsmessage.TypeAAAA:
+ r, err := p.AAAAResource()
+ if err != nil {
+ panic(err)
+ }
+ gotIPs = append(gotIPs, r.AAAA[:])
+ }
+ }
+
+ fmt.Printf("Found A/AAAA records for name %s: %v\n", wantName, gotIPs)
+
+ // Output:
+ // Found question for name bar.example.com.
+ // Found A/AAAA records for name bar.example.com.: [127.0.0.2]
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package dnsmessage provides a mostly RFC 1035 compliant implementation of
+// DNS message packing and unpacking.
+//
+// This implementation is designed to minimize heap allocations and avoid
+// unnecessary packing and unpacking as much as possible.
+package dnsmessage
+
+import (
+ "errors"
+)
+
+// Message formats
+
+// A Type is a type of DNS request and response.
+type Type uint16
+
+// A Class is a type of network.
+type Class uint16
+
+// An OpCode is a DNS operation code.
+type OpCode uint16
+
+// An RCode is a DNS response status code.
+type RCode uint16
+
+// Wire constants.
+const (
+ // ResourceHeader.Type and Question.Type
+ TypeA Type = 1
+ TypeNS Type = 2
+ TypeCNAME Type = 5
+ TypeSOA Type = 6
+ TypePTR Type = 12
+ TypeMX Type = 15
+ TypeTXT Type = 16
+ TypeAAAA Type = 28
+ TypeSRV Type = 33
+
+ // Question.Type
+ TypeWKS Type = 11
+ TypeHINFO Type = 13
+ TypeMINFO Type = 14
+ TypeAXFR Type = 252
+ TypeALL Type = 255
+
+ // ResourceHeader.Class and Question.Class
+ ClassINET Class = 1
+ ClassCSNET Class = 2
+ ClassCHAOS Class = 3
+ ClassHESIOD Class = 4
+
+ // Question.Class
+ ClassANY Class = 255
+
+ // Message.Rcode
+ RCodeSuccess RCode = 0
+ RCodeFormatError RCode = 1
+ RCodeServerFailure RCode = 2
+ RCodeNameError RCode = 3
+ RCodeNotImplemented RCode = 4
+ RCodeRefused RCode = 5
+)
+
+var (
+ // ErrNotStarted indicates that the prerequisite information isn't
+ // available yet because the previous records haven't been appropriately
+ // parsed, skipped or finished.
+ ErrNotStarted = errors.New("parsing/packing of this type isn't available yet")
+
+ // ErrSectionDone indicated that all records in the section have been
+ // parsed or finished.
+ ErrSectionDone = errors.New("parsing/packing of this section has completed")
+
+ errBaseLen = errors.New("insufficient data for base length type")
+ errCalcLen = errors.New("insufficient data for calculated length type")
+ errReserved = errors.New("segment prefix is reserved")
+ errTooManyPtr = errors.New("too many pointers (>10)")
+ errInvalidPtr = errors.New("invalid pointer")
+ errNilResouceBody = errors.New("nil resource body")
+ errResourceLen = errors.New("insufficient data for resource body length")
+ errSegTooLong = errors.New("segment length too long")
+ errZeroSegLen = errors.New("zero length segment")
+ errResTooLong = errors.New("resource length too long")
+ errTooManyQuestions = errors.New("too many Questions to pack (>65535)")
+ errTooManyAnswers = errors.New("too many Answers to pack (>65535)")
+ errTooManyAuthorities = errors.New("too many Authorities to pack (>65535)")
+ errTooManyAdditionals = errors.New("too many Additionals to pack (>65535)")
+ errNonCanonicalName = errors.New("name is not in canonical format (it must end with a .)")
+ errStringTooLong = errors.New("character string exceeds maximum length (255)")
+)
+
+// Internal constants.
+const (
+ // packStartingCap is the default initial buffer size allocated during
+ // packing.
+ //
+ // The starting capacity doesn't matter too much, but most DNS responses
+ // Will be <= 512 bytes as it is the limit for DNS over UDP.
+ packStartingCap = 512
+
+ // uint16Len is the length (in bytes) of a uint16.
+ uint16Len = 2
+
+ // uint32Len is the length (in bytes) of a uint32.
+ uint32Len = 4
+
+ // headerLen is the length (in bytes) of a DNS header.
+ //
+ // A header is comprised of 6 uint16s and no padding.
+ headerLen = 6 * uint16Len
+)
+
+type nestedError struct {
+ // s is the current level's error message.
+ s string
+
+ // err is the nested error.
+ err error
+}
+
+// nestedError implements error.Error.
+func (e *nestedError) Error() string {
+ return e.s + ": " + e.err.Error()
+}
+
+// Header is a representation of a DNS message header.
+type Header struct {
+ ID uint16
+ Response bool
+ OpCode OpCode
+ Authoritative bool
+ Truncated bool
+ RecursionDesired bool
+ RecursionAvailable bool
+ RCode RCode
+}
+
+func (m *Header) pack() (id uint16, bits uint16) {
+ id = m.ID
+ bits = uint16(m.OpCode)<<11 | uint16(m.RCode)
+ if m.RecursionAvailable {
+ bits |= headerBitRA
+ }
+ if m.RecursionDesired {
+ bits |= headerBitRD
+ }
+ if m.Truncated {
+ bits |= headerBitTC
+ }
+ if m.Authoritative {
+ bits |= headerBitAA
+ }
+ if m.Response {
+ bits |= headerBitQR
+ }
+ return
+}
+
+// Message is a representation of a DNS message.
+type Message struct {
+ Header
+ Questions []Question
+ Answers []Resource
+ Authorities []Resource
+ Additionals []Resource
+}
+
+type section uint8
+
+const (
+ sectionNotStarted section = iota
+ sectionHeader
+ sectionQuestions
+ sectionAnswers
+ sectionAuthorities
+ sectionAdditionals
+ sectionDone
+
+ headerBitQR = 1 << 15 // query/response (response=1)
+ headerBitAA = 1 << 10 // authoritative
+ headerBitTC = 1 << 9 // truncated
+ headerBitRD = 1 << 8 // recursion desired
+ headerBitRA = 1 << 7 // recursion available
+)
+
+var sectionNames = map[section]string{
+ sectionHeader: "header",
+ sectionQuestions: "Question",
+ sectionAnswers: "Answer",
+ sectionAuthorities: "Authority",
+ sectionAdditionals: "Additional",
+}
+
+// header is the wire format for a DNS message header.
+type header struct {
+ id uint16
+ bits uint16
+ questions uint16
+ answers uint16
+ authorities uint16
+ additionals uint16
+}
+
+func (h *header) count(sec section) uint16 {
+ switch sec {
+ case sectionQuestions:
+ return h.questions
+ case sectionAnswers:
+ return h.answers
+ case sectionAuthorities:
+ return h.authorities
+ case sectionAdditionals:
+ return h.additionals
+ }
+ return 0
+}
+
+// pack appends the wire format of the header to msg.
+func (h *header) pack(msg []byte) []byte {
+ msg = packUint16(msg, h.id)
+ msg = packUint16(msg, h.bits)
+ msg = packUint16(msg, h.questions)
+ msg = packUint16(msg, h.answers)
+ msg = packUint16(msg, h.authorities)
+ return packUint16(msg, h.additionals)
+}
+
+func (h *header) unpack(msg []byte, off int) (int, error) {
+ newOff := off
+ var err error
+ if h.id, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"id", err}
+ }
+ if h.bits, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"bits", err}
+ }
+ if h.questions, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"questions", err}
+ }
+ if h.answers, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"answers", err}
+ }
+ if h.authorities, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"authorities", err}
+ }
+ if h.additionals, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"additionals", err}
+ }
+ return newOff, nil
+}
+
+func (h *header) header() Header {
+ return Header{
+ ID: h.id,
+ Response: (h.bits & headerBitQR) != 0,
+ OpCode: OpCode(h.bits>>11) & 0xF,
+ Authoritative: (h.bits & headerBitAA) != 0,
+ Truncated: (h.bits & headerBitTC) != 0,
+ RecursionDesired: (h.bits & headerBitRD) != 0,
+ RecursionAvailable: (h.bits & headerBitRA) != 0,
+ RCode: RCode(h.bits & 0xF),
+ }
+}
+
+// A Resource is a DNS resource record.
+type Resource struct {
+ Header ResourceHeader
+ Body ResourceBody
+}
+
+// A ResourceBody is a DNS resource record minus the header.
+type ResourceBody interface {
+ // pack packs a Resource except for its header.
+ pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error)
+
+ // realType returns the actual type of the Resource. This is used to
+ // fill in the header Type field.
+ realType() Type
+}
+
+// pack appends the wire format of the Resource to msg.
+func (r *Resource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ if r.Body == nil {
+ return msg, errNilResouceBody
+ }
+ oldMsg := msg
+ r.Header.Type = r.Body.realType()
+ msg, length, err := r.Header.pack(msg, compression, compressionOff)
+ if err != nil {
+ return msg, &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ msg, err = r.Body.pack(msg, compression, compressionOff)
+ if err != nil {
+ return msg, &nestedError{"content", err}
+ }
+ if err := r.Header.fixLen(msg, length, preLen); err != nil {
+ return oldMsg, err
+ }
+ return msg, nil
+}
+
+// A Parser allows incrementally parsing a DNS message.
+//
+// When parsing is started, the Header is parsed. Next, each Question can be
+// either parsed or skipped. Alternatively, all Questions can be skipped at
+// once. When all Questions have been parsed, attempting to parse Questions
+// will return (nil, nil) and attempting to skip Questions will return
+// (true, nil). After all Questions have been either parsed or skipped, all
+// Answers, Authorities and Additionals can be either parsed or skipped in the
+// same way, and each type of Resource must be fully parsed or skipped before
+// proceeding to the next type of Resource.
+//
+// Note that there is no requirement to fully skip or parse the message.
+type Parser struct {
+ msg []byte
+ header header
+
+ section section
+ off int
+ index int
+ resHeaderValid bool
+ resHeader ResourceHeader
+}
+
+// Start parses the header and enables the parsing of Questions.
+func (p *Parser) Start(msg []byte) (Header, error) {
+ if p.msg != nil {
+ *p = Parser{}
+ }
+ p.msg = msg
+ var err error
+ if p.off, err = p.header.unpack(msg, 0); err != nil {
+ return Header{}, &nestedError{"unpacking header", err}
+ }
+ p.section = sectionQuestions
+ return p.header.header(), nil
+}
+
+func (p *Parser) checkAdvance(sec section) error {
+ if p.section < sec {
+ return ErrNotStarted
+ }
+ if p.section > sec {
+ return ErrSectionDone
+ }
+ p.resHeaderValid = false
+ if p.index == int(p.header.count(sec)) {
+ p.index = 0
+ p.section++
+ return ErrSectionDone
+ }
+ return nil
+}
+
+func (p *Parser) resource(sec section) (Resource, error) {
+ var r Resource
+ var err error
+ r.Header, err = p.resourceHeader(sec)
+ if err != nil {
+ return r, err
+ }
+ p.resHeaderValid = false
+ r.Body, p.off, err = unpackResourceBody(p.msg, p.off, r.Header)
+ if err != nil {
+ return Resource{}, &nestedError{"unpacking " + sectionNames[sec], err}
+ }
+ p.index++
+ return r, nil
+}
+
+func (p *Parser) resourceHeader(sec section) (ResourceHeader, error) {
+ if p.resHeaderValid {
+ return p.resHeader, nil
+ }
+ if err := p.checkAdvance(sec); err != nil {
+ return ResourceHeader{}, err
+ }
+ var hdr ResourceHeader
+ off, err := hdr.unpack(p.msg, p.off)
+ if err != nil {
+ return ResourceHeader{}, err
+ }
+ p.resHeaderValid = true
+ p.resHeader = hdr
+ p.off = off
+ return hdr, nil
+}
+
+func (p *Parser) skipResource(sec section) error {
+ if p.resHeaderValid {
+ newOff := p.off + int(p.resHeader.Length)
+ if newOff > len(p.msg) {
+ return errResourceLen
+ }
+ p.off = newOff
+ p.resHeaderValid = false
+ p.index++
+ return nil
+ }
+ if err := p.checkAdvance(sec); err != nil {
+ return err
+ }
+ var err error
+ p.off, err = skipResource(p.msg, p.off)
+ if err != nil {
+ return &nestedError{"skipping: " + sectionNames[sec], err}
+ }
+ p.index++
+ return nil
+}
+
+// Question parses a single Question.
+func (p *Parser) Question() (Question, error) {
+ if err := p.checkAdvance(sectionQuestions); err != nil {
+ return Question{}, err
+ }
+ var name Name
+ off, err := name.unpack(p.msg, p.off)
+ if err != nil {
+ return Question{}, &nestedError{"unpacking Question.Name", err}
+ }
+ typ, off, err := unpackType(p.msg, off)
+ if err != nil {
+ return Question{}, &nestedError{"unpacking Question.Type", err}
+ }
+ class, off, err := unpackClass(p.msg, off)
+ if err != nil {
+ return Question{}, &nestedError{"unpacking Question.Class", err}
+ }
+ p.off = off
+ p.index++
+ return Question{name, typ, class}, nil
+}
+
+// AllQuestions parses all Questions.
+func (p *Parser) AllQuestions() ([]Question, error) {
+ // Multiple questions are valid according to the spec,
+ // but servers don't actually support them. There will
+ // be at most one question here.
+ //
+ // Do not pre-allocate based on info in p.header, since
+ // the data is untrusted.
+ qs := []Question{}
+ for {
+ q, err := p.Question()
+ if err == ErrSectionDone {
+ return qs, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ qs = append(qs, q)
+ }
+}
+
+// SkipQuestion skips a single Question.
+func (p *Parser) SkipQuestion() error {
+ if err := p.checkAdvance(sectionQuestions); err != nil {
+ return err
+ }
+ off, err := skipName(p.msg, p.off)
+ if err != nil {
+ return &nestedError{"skipping Question Name", err}
+ }
+ if off, err = skipType(p.msg, off); err != nil {
+ return &nestedError{"skipping Question Type", err}
+ }
+ if off, err = skipClass(p.msg, off); err != nil {
+ return &nestedError{"skipping Question Class", err}
+ }
+ p.off = off
+ p.index++
+ return nil
+}
+
+// SkipAllQuestions skips all Questions.
+func (p *Parser) SkipAllQuestions() error {
+ for {
+ if err := p.SkipQuestion(); err == ErrSectionDone {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+// AnswerHeader parses a single Answer ResourceHeader.
+func (p *Parser) AnswerHeader() (ResourceHeader, error) {
+ return p.resourceHeader(sectionAnswers)
+}
+
+// Answer parses a single Answer Resource.
+func (p *Parser) Answer() (Resource, error) {
+ return p.resource(sectionAnswers)
+}
+
+// AllAnswers parses all Answer Resources.
+func (p *Parser) AllAnswers() ([]Resource, error) {
+ // The most common query is for A/AAAA, which usually returns
+ // a handful of IPs.
+ //
+ // Pre-allocate up to a certain limit, since p.header is
+ // untrusted data.
+ n := int(p.header.answers)
+ if n > 20 {
+ n = 20
+ }
+ as := make([]Resource, 0, n)
+ for {
+ a, err := p.Answer()
+ if err == ErrSectionDone {
+ return as, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ as = append(as, a)
+ }
+}
+
+// SkipAnswer skips a single Answer Resource.
+func (p *Parser) SkipAnswer() error {
+ return p.skipResource(sectionAnswers)
+}
+
+// SkipAllAnswers skips all Answer Resources.
+func (p *Parser) SkipAllAnswers() error {
+ for {
+ if err := p.SkipAnswer(); err == ErrSectionDone {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+// AuthorityHeader parses a single Authority ResourceHeader.
+func (p *Parser) AuthorityHeader() (ResourceHeader, error) {
+ return p.resourceHeader(sectionAuthorities)
+}
+
+// Authority parses a single Authority Resource.
+func (p *Parser) Authority() (Resource, error) {
+ return p.resource(sectionAuthorities)
+}
+
+// AllAuthorities parses all Authority Resources.
+func (p *Parser) AllAuthorities() ([]Resource, error) {
+ // Authorities contains SOA in case of NXDOMAIN and friends,
+ // otherwise it is empty.
+ //
+ // Pre-allocate up to a certain limit, since p.header is
+ // untrusted data.
+ n := int(p.header.authorities)
+ if n > 10 {
+ n = 10
+ }
+ as := make([]Resource, 0, n)
+ for {
+ a, err := p.Authority()
+ if err == ErrSectionDone {
+ return as, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ as = append(as, a)
+ }
+}
+
+// SkipAuthority skips a single Authority Resource.
+func (p *Parser) SkipAuthority() error {
+ return p.skipResource(sectionAuthorities)
+}
+
+// SkipAllAuthorities skips all Authority Resources.
+func (p *Parser) SkipAllAuthorities() error {
+ for {
+ if err := p.SkipAuthority(); err == ErrSectionDone {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+// AdditionalHeader parses a single Additional ResourceHeader.
+func (p *Parser) AdditionalHeader() (ResourceHeader, error) {
+ return p.resourceHeader(sectionAdditionals)
+}
+
+// Additional parses a single Additional Resource.
+func (p *Parser) Additional() (Resource, error) {
+ return p.resource(sectionAdditionals)
+}
+
+// AllAdditionals parses all Additional Resources.
+func (p *Parser) AllAdditionals() ([]Resource, error) {
+ // Additionals usually contain OPT, and sometimes A/AAAA
+ // glue records.
+ //
+ // Pre-allocate up to a certain limit, since p.header is
+ // untrusted data.
+ n := int(p.header.additionals)
+ if n > 10 {
+ n = 10
+ }
+ as := make([]Resource, 0, n)
+ for {
+ a, err := p.Additional()
+ if err == ErrSectionDone {
+ return as, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ as = append(as, a)
+ }
+}
+
+// SkipAdditional skips a single Additional Resource.
+func (p *Parser) SkipAdditional() error {
+ return p.skipResource(sectionAdditionals)
+}
+
+// SkipAllAdditionals skips all Additional Resources.
+func (p *Parser) SkipAllAdditionals() error {
+ for {
+ if err := p.SkipAdditional(); err == ErrSectionDone {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+// CNAMEResource parses a single CNAMEResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) CNAMEResource() (CNAMEResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypeCNAME {
+ return CNAMEResource{}, ErrNotStarted
+ }
+ r, err := unpackCNAMEResource(p.msg, p.off)
+ if err != nil {
+ return CNAMEResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// MXResource parses a single MXResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) MXResource() (MXResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypeMX {
+ return MXResource{}, ErrNotStarted
+ }
+ r, err := unpackMXResource(p.msg, p.off)
+ if err != nil {
+ return MXResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// NSResource parses a single NSResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) NSResource() (NSResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypeNS {
+ return NSResource{}, ErrNotStarted
+ }
+ r, err := unpackNSResource(p.msg, p.off)
+ if err != nil {
+ return NSResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// PTRResource parses a single PTRResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) PTRResource() (PTRResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypePTR {
+ return PTRResource{}, ErrNotStarted
+ }
+ r, err := unpackPTRResource(p.msg, p.off)
+ if err != nil {
+ return PTRResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// SOAResource parses a single SOAResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) SOAResource() (SOAResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypeSOA {
+ return SOAResource{}, ErrNotStarted
+ }
+ r, err := unpackSOAResource(p.msg, p.off)
+ if err != nil {
+ return SOAResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// TXTResource parses a single TXTResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) TXTResource() (TXTResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypeTXT {
+ return TXTResource{}, ErrNotStarted
+ }
+ r, err := unpackTXTResource(p.msg, p.off, p.resHeader.Length)
+ if err != nil {
+ return TXTResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// SRVResource parses a single SRVResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) SRVResource() (SRVResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypeSRV {
+ return SRVResource{}, ErrNotStarted
+ }
+ r, err := unpackSRVResource(p.msg, p.off)
+ if err != nil {
+ return SRVResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// AResource parses a single AResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) AResource() (AResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypeA {
+ return AResource{}, ErrNotStarted
+ }
+ r, err := unpackAResource(p.msg, p.off)
+ if err != nil {
+ return AResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// AAAAResource parses a single AAAAResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) AAAAResource() (AAAAResource, error) {
+ if !p.resHeaderValid || p.resHeader.Type != TypeAAAA {
+ return AAAAResource{}, ErrNotStarted
+ }
+ r, err := unpackAAAAResource(p.msg, p.off)
+ if err != nil {
+ return AAAAResource{}, err
+ }
+ p.off += int(p.resHeader.Length)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// Unpack parses a full Message.
+func (m *Message) Unpack(msg []byte) error {
+ var p Parser
+ var err error
+ if m.Header, err = p.Start(msg); err != nil {
+ return err
+ }
+ if m.Questions, err = p.AllQuestions(); err != nil {
+ return err
+ }
+ if m.Answers, err = p.AllAnswers(); err != nil {
+ return err
+ }
+ if m.Authorities, err = p.AllAuthorities(); err != nil {
+ return err
+ }
+ if m.Additionals, err = p.AllAdditionals(); err != nil {
+ return err
+ }
+ return nil
+}
+
+// Pack packs a full Message.
+func (m *Message) Pack() ([]byte, error) {
+ return m.AppendPack(make([]byte, 0, packStartingCap))
+}
+
+// AppendPack is like Pack but appends the full Message to b and returns the
+// extended buffer.
+func (m *Message) AppendPack(b []byte) ([]byte, error) {
+ // Validate the lengths. It is very unlikely that anyone will try to
+ // pack more than 65535 of any particular type, but it is possible and
+ // we should fail gracefully.
+ if len(m.Questions) > int(^uint16(0)) {
+ return nil, errTooManyQuestions
+ }
+ if len(m.Answers) > int(^uint16(0)) {
+ return nil, errTooManyAnswers
+ }
+ if len(m.Authorities) > int(^uint16(0)) {
+ return nil, errTooManyAuthorities
+ }
+ if len(m.Additionals) > int(^uint16(0)) {
+ return nil, errTooManyAdditionals
+ }
+
+ var h header
+ h.id, h.bits = m.Header.pack()
+
+ h.questions = uint16(len(m.Questions))
+ h.answers = uint16(len(m.Answers))
+ h.authorities = uint16(len(m.Authorities))
+ h.additionals = uint16(len(m.Additionals))
+
+ compressionOff := len(b)
+ msg := h.pack(b)
+
+ // RFC 1035 allows (but does not require) compression for packing. RFC
+ // 1035 requires unpacking implementations to support compression, so
+ // unconditionally enabling it is fine.
+ //
+ // DNS lookups are typically done over UDP, and RFC 1035 states that UDP
+ // DNS messages can be a maximum of 512 bytes long. Without compression,
+ // many DNS response messages are over this limit, so enabling
+ // compression will help ensure compliance.
+ compression := map[string]int{}
+
+ for i := range m.Questions {
+ var err error
+ if msg, err = m.Questions[i].pack(msg, compression, compressionOff); err != nil {
+ return nil, &nestedError{"packing Question", err}
+ }
+ }
+ for i := range m.Answers {
+ var err error
+ if msg, err = m.Answers[i].pack(msg, compression, compressionOff); err != nil {
+ return nil, &nestedError{"packing Answer", err}
+ }
+ }
+ for i := range m.Authorities {
+ var err error
+ if msg, err = m.Authorities[i].pack(msg, compression, compressionOff); err != nil {
+ return nil, &nestedError{"packing Authority", err}
+ }
+ }
+ for i := range m.Additionals {
+ var err error
+ if msg, err = m.Additionals[i].pack(msg, compression, compressionOff); err != nil {
+ return nil, &nestedError{"packing Additional", err}
+ }
+ }
+
+ return msg, nil
+}
+
+// A Builder allows incrementally packing a DNS message.
+//
+// Example usage:
+// buf := make([]byte, 2, 514)
+// b := NewBuilder(buf, Header{...})
+// b.EnableCompression()
+// // Optionally start a section and add things to that section.
+// // Repeat adding sections as necessary.
+// buf, err := b.Finish()
+// // If err is nil, buf[2:] will contain the built bytes.
+type Builder struct {
+ // msg is the storage for the message being built.
+ msg []byte
+
+ // section keeps track of the current section being built.
+ section section
+
+ // header keeps track of what should go in the header when Finish is
+ // called.
+ header header
+
+ // start is the starting index of the bytes allocated in msg for header.
+ start int
+
+ // compression is a mapping from name suffixes to their starting index
+ // in msg.
+ compression map[string]int
+}
+
+// NewBuilder creates a new builder with compression disabled.
+//
+// Note: Most users will want to immediately enable compression with the
+// EnableCompression method. See that method's comment for why you may or may
+// not want to enable compression.
+//
+// The DNS message is appended to the provided initial buffer buf (which may be
+// nil) as it is built. The final message is returned by the (*Builder).Finish
+// method, which may return the same underlying array if there was sufficient
+// capacity in the slice.
+func NewBuilder(buf []byte, h Header) Builder {
+ if buf == nil {
+ buf = make([]byte, 0, packStartingCap)
+ }
+ b := Builder{msg: buf, start: len(buf)}
+ b.header.id, b.header.bits = h.pack()
+ var hb [headerLen]byte
+ b.msg = append(b.msg, hb[:]...)
+ b.section = sectionHeader
+ return b
+}
+
+// EnableCompression enables compression in the Builder.
+//
+// Leaving compression disabled avoids compression related allocations, but can
+// result in larger message sizes. Be careful with this mode as it can cause
+// messages to exceed the UDP size limit.
+//
+// According to RFC 1035, section 4.1.4, the use of compression is optional, but
+// all implementations must accept both compressed and uncompressed DNS
+// messages.
+//
+// Compression should be enabled before any sections are added for best results.
+func (b *Builder) EnableCompression() {
+ b.compression = map[string]int{}
+}
+
+func (b *Builder) startCheck(s section) error {
+ if b.section <= sectionNotStarted {
+ return ErrNotStarted
+ }
+ if b.section > s {
+ return ErrSectionDone
+ }
+ return nil
+}
+
+// StartQuestions prepares the builder for packing Questions.
+func (b *Builder) StartQuestions() error {
+ if err := b.startCheck(sectionQuestions); err != nil {
+ return err
+ }
+ b.section = sectionQuestions
+ return nil
+}
+
+// StartAnswers prepares the builder for packing Answers.
+func (b *Builder) StartAnswers() error {
+ if err := b.startCheck(sectionAnswers); err != nil {
+ return err
+ }
+ b.section = sectionAnswers
+ return nil
+}
+
+// StartAuthorities prepares the builder for packing Authorities.
+func (b *Builder) StartAuthorities() error {
+ if err := b.startCheck(sectionAuthorities); err != nil {
+ return err
+ }
+ b.section = sectionAuthorities
+ return nil
+}
+
+// StartAdditionals prepares the builder for packing Additionals.
+func (b *Builder) StartAdditionals() error {
+ if err := b.startCheck(sectionAdditionals); err != nil {
+ return err
+ }
+ b.section = sectionAdditionals
+ return nil
+}
+
+func (b *Builder) incrementSectionCount() error {
+ var count *uint16
+ var err error
+ switch b.section {
+ case sectionQuestions:
+ count = &b.header.questions
+ err = errTooManyQuestions
+ case sectionAnswers:
+ count = &b.header.answers
+ err = errTooManyAnswers
+ case sectionAuthorities:
+ count = &b.header.authorities
+ err = errTooManyAuthorities
+ case sectionAdditionals:
+ count = &b.header.additionals
+ err = errTooManyAdditionals
+ }
+ if *count == ^uint16(0) {
+ return err
+ }
+ *count++
+ return nil
+}
+
+// Question adds a single Question.
+func (b *Builder) Question(q Question) error {
+ if b.section < sectionQuestions {
+ return ErrNotStarted
+ }
+ if b.section > sectionQuestions {
+ return ErrSectionDone
+ }
+ msg, err := q.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+func (b *Builder) checkResourceSection() error {
+ if b.section < sectionAnswers {
+ return ErrNotStarted
+ }
+ if b.section > sectionAdditionals {
+ return ErrSectionDone
+ }
+ return nil
+}
+
+// CNAMEResource adds a single CNAMEResource.
+func (b *Builder) CNAMEResource(h ResourceHeader, r CNAMEResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"CNAMEResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// MXResource adds a single MXResource.
+func (b *Builder) MXResource(h ResourceHeader, r MXResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"MXResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// NSResource adds a single NSResource.
+func (b *Builder) NSResource(h ResourceHeader, r NSResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"NSResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// PTRResource adds a single PTRResource.
+func (b *Builder) PTRResource(h ResourceHeader, r PTRResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"PTRResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// SOAResource adds a single SOAResource.
+func (b *Builder) SOAResource(h ResourceHeader, r SOAResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"SOAResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// TXTResource adds a single TXTResource.
+func (b *Builder) TXTResource(h ResourceHeader, r TXTResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"TXTResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// SRVResource adds a single SRVResource.
+func (b *Builder) SRVResource(h ResourceHeader, r SRVResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"SRVResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// AResource adds a single AResource.
+func (b *Builder) AResource(h ResourceHeader, r AResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"AResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// AAAAResource adds a single AAAAResource.
+func (b *Builder) AAAAResource(h ResourceHeader, r AAAAResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ h.Type = r.realType()
+ msg, length, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"AAAAResource body", err}
+ }
+ if err := h.fixLen(msg, length, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// Finish ends message building and generates a binary message.
+func (b *Builder) Finish() ([]byte, error) {
+ if b.section < sectionHeader {
+ return nil, ErrNotStarted
+ }
+ b.section = sectionDone
+ // Space for the header was allocated in NewBuilder.
+ b.header.pack(b.msg[b.start:b.start])
+ return b.msg, nil
+}
+
+// A ResourceHeader is the header of a DNS resource record. There are
+// many types of DNS resource records, but they all share the same header.
+type ResourceHeader struct {
+ // Name is the domain name for which this resource record pertains.
+ Name Name
+
+ // Type is the type of DNS resource record.
+ //
+ // This field will be set automatically during packing.
+ Type Type
+
+ // Class is the class of network to which this DNS resource record
+ // pertains.
+ Class Class
+
+ // TTL is the length of time (measured in seconds) which this resource
+ // record is valid for (time to live). All Resources in a set should
+ // have the same TTL (RFC 2181 Section 5.2).
+ TTL uint32
+
+ // Length is the length of data in the resource record after the header.
+ //
+ // This field will be set automatically during packing.
+ Length uint16
+}
+
+// pack appends the wire format of the ResourceHeader to oldMsg.
+//
+// The bytes where length was packed are returned as a slice so they can be
+// updated after the rest of the Resource has been packed.
+func (h *ResourceHeader) pack(oldMsg []byte, compression map[string]int, compressionOff int) (msg []byte, length []byte, err error) {
+ msg = oldMsg
+ if msg, err = h.Name.pack(msg, compression, compressionOff); err != nil {
+ return oldMsg, nil, &nestedError{"Name", err}
+ }
+ msg = packType(msg, h.Type)
+ msg = packClass(msg, h.Class)
+ msg = packUint32(msg, h.TTL)
+ lenBegin := len(msg)
+ msg = packUint16(msg, h.Length)
+ return msg, msg[lenBegin : lenBegin+uint16Len], nil
+}
+
+func (h *ResourceHeader) unpack(msg []byte, off int) (int, error) {
+ newOff := off
+ var err error
+ if newOff, err = h.Name.unpack(msg, newOff); err != nil {
+ return off, &nestedError{"Name", err}
+ }
+ if h.Type, newOff, err = unpackType(msg, newOff); err != nil {
+ return off, &nestedError{"Type", err}
+ }
+ if h.Class, newOff, err = unpackClass(msg, newOff); err != nil {
+ return off, &nestedError{"Class", err}
+ }
+ if h.TTL, newOff, err = unpackUint32(msg, newOff); err != nil {
+ return off, &nestedError{"TTL", err}
+ }
+ if h.Length, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"Length", err}
+ }
+ return newOff, nil
+}
+
+func (h *ResourceHeader) fixLen(msg []byte, length []byte, preLen int) error {
+ conLen := len(msg) - preLen
+ if conLen > int(^uint16(0)) {
+ return errResTooLong
+ }
+
+ // Fill in the length now that we know how long the content is.
+ packUint16(length[:0], uint16(conLen))
+ h.Length = uint16(conLen)
+
+ return nil
+}
+
+func skipResource(msg []byte, off int) (int, error) {
+ newOff, err := skipName(msg, off)
+ if err != nil {
+ return off, &nestedError{"Name", err}
+ }
+ if newOff, err = skipType(msg, newOff); err != nil {
+ return off, &nestedError{"Type", err}
+ }
+ if newOff, err = skipClass(msg, newOff); err != nil {
+ return off, &nestedError{"Class", err}
+ }
+ if newOff, err = skipUint32(msg, newOff); err != nil {
+ return off, &nestedError{"TTL", err}
+ }
+ length, newOff, err := unpackUint16(msg, newOff)
+ if err != nil {
+ return off, &nestedError{"Length", err}
+ }
+ if newOff += int(length); newOff > len(msg) {
+ return off, errResourceLen
+ }
+ return newOff, nil
+}
+
+// packUint16 appends the wire format of field to msg.
+func packUint16(msg []byte, field uint16) []byte {
+ return append(msg, byte(field>>8), byte(field))
+}
+
+func unpackUint16(msg []byte, off int) (uint16, int, error) {
+ if off+uint16Len > len(msg) {
+ return 0, off, errBaseLen
+ }
+ return uint16(msg[off])<<8 | uint16(msg[off+1]), off + uint16Len, nil
+}
+
+func skipUint16(msg []byte, off int) (int, error) {
+ if off+uint16Len > len(msg) {
+ return off, errBaseLen
+ }
+ return off + uint16Len, nil
+}
+
+// packType appends the wire format of field to msg.
+func packType(msg []byte, field Type) []byte {
+ return packUint16(msg, uint16(field))
+}
+
+func unpackType(msg []byte, off int) (Type, int, error) {
+ t, o, err := unpackUint16(msg, off)
+ return Type(t), o, err
+}
+
+func skipType(msg []byte, off int) (int, error) {
+ return skipUint16(msg, off)
+}
+
+// packClass appends the wire format of field to msg.
+func packClass(msg []byte, field Class) []byte {
+ return packUint16(msg, uint16(field))
+}
+
+func unpackClass(msg []byte, off int) (Class, int, error) {
+ c, o, err := unpackUint16(msg, off)
+ return Class(c), o, err
+}
+
+func skipClass(msg []byte, off int) (int, error) {
+ return skipUint16(msg, off)
+}
+
+// packUint32 appends the wire format of field to msg.
+func packUint32(msg []byte, field uint32) []byte {
+ return append(
+ msg,
+ byte(field>>24),
+ byte(field>>16),
+ byte(field>>8),
+ byte(field),
+ )
+}
+
+func unpackUint32(msg []byte, off int) (uint32, int, error) {
+ if off+uint32Len > len(msg) {
+ return 0, off, errBaseLen
+ }
+ v := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
+ return v, off + uint32Len, nil
+}
+
+func skipUint32(msg []byte, off int) (int, error) {
+ if off+uint32Len > len(msg) {
+ return off, errBaseLen
+ }
+ return off + uint32Len, nil
+}
+
+// packText appends the wire format of field to msg.
+func packText(msg []byte, field string) ([]byte, error) {
+ l := len(field)
+ if l > 255 {
+ return nil, errStringTooLong
+ }
+ msg = append(msg, byte(l))
+ msg = append(msg, field...)
+
+ return msg, nil
+}
+
+func unpackText(msg []byte, off int) (string, int, error) {
+ if off >= len(msg) {
+ return "", off, errBaseLen
+ }
+ beginOff := off + 1
+ endOff := beginOff + int(msg[off])
+ if endOff > len(msg) {
+ return "", off, errCalcLen
+ }
+ return string(msg[beginOff:endOff]), endOff, nil
+}
+
+func skipText(msg []byte, off int) (int, error) {
+ if off >= len(msg) {
+ return off, errBaseLen
+ }
+ endOff := off + 1 + int(msg[off])
+ if endOff > len(msg) {
+ return off, errCalcLen
+ }
+ return endOff, nil
+}
+
+// packBytes appends the wire format of field to msg.
+func packBytes(msg []byte, field []byte) []byte {
+ return append(msg, field...)
+}
+
+func unpackBytes(msg []byte, off int, field []byte) (int, error) {
+ newOff := off + len(field)
+ if newOff > len(msg) {
+ return off, errBaseLen
+ }
+ copy(field, msg[off:newOff])
+ return newOff, nil
+}
+
+func skipBytes(msg []byte, off int, field []byte) (int, error) {
+ newOff := off + len(field)
+ if newOff > len(msg) {
+ return off, errBaseLen
+ }
+ return newOff, nil
+}
+
+const nameLen = 255
+
+// A Name is a non-encoded domain name. It is used instead of strings to avoid
+// allocations.
+type Name struct {
+ Data [nameLen]byte
+ Length uint8
+}
+
+// NewName creates a new Name from a string.
+func NewName(name string) (Name, error) {
+ if len([]byte(name)) > nameLen {
+ return Name{}, errCalcLen
+ }
+ n := Name{Length: uint8(len(name))}
+ copy(n.Data[:], []byte(name))
+ return n, nil
+}
+
+func (n Name) String() string {
+ return string(n.Data[:n.Length])
+}
+
+// pack appends the wire format of the Name to msg.
+//
+// Domain names are a sequence of counted strings split at the dots. They end
+// with a zero-length string. Compression can be used to reuse domain suffixes.
+//
+// The compression map will be updated with new domain suffixes. If compression
+// is nil, compression will not be used.
+func (n *Name) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ oldMsg := msg
+
+ // Add a trailing dot to canonicalize name.
+ if n.Length == 0 || n.Data[n.Length-1] != '.' {
+ return oldMsg, errNonCanonicalName
+ }
+
+ // Allow root domain.
+ if n.Data[0] == '.' && n.Length == 1 {
+ return append(msg, 0), nil
+ }
+
+ // Emit sequence of counted strings, chopping at dots.
+ for i, begin := 0, 0; i < int(n.Length); i++ {
+ // Check for the end of the segment.
+ if n.Data[i] == '.' {
+ // The two most significant bits have special meaning.
+ // It isn't allowed for segments to be long enough to
+ // need them.
+ if i-begin >= 1<<6 {
+ return oldMsg, errSegTooLong
+ }
+
+ // Segments must have a non-zero length.
+ if i-begin == 0 {
+ return oldMsg, errZeroSegLen
+ }
+
+ msg = append(msg, byte(i-begin))
+
+ for j := begin; j < i; j++ {
+ msg = append(msg, n.Data[j])
+ }
+
+ begin = i + 1
+ continue
+ }
+
+ // We can only compress domain suffixes starting with a new
+ // segment. A pointer is two bytes with the two most significant
+ // bits set to 1 to indicate that it is a pointer.
+ if (i == 0 || n.Data[i-1] == '.') && compression != nil {
+ if ptr, ok := compression[string(n.Data[i:])]; ok {
+ // Hit. Emit a pointer instead of the rest of
+ // the domain.
+ return append(msg, byte(ptr>>8|0xC0), byte(ptr)), nil
+ }
+
+ // Miss. Add the suffix to the compression table if the
+ // offset can be stored in the available 14 bytes.
+ if len(msg) <= int(^uint16(0)>>2) {
+ compression[string(n.Data[i:])] = len(msg) - compressionOff
+ }
+ }
+ }
+ return append(msg, 0), nil
+}
+
+// unpack unpacks a domain name.
+func (n *Name) unpack(msg []byte, off int) (int, error) {
+ // currOff is the current working offset.
+ currOff := off
+
+ // newOff is the offset where the next record will start. Pointers lead
+ // to data that belongs to other names and thus doesn't count towards to
+ // the usage of this name.
+ newOff := off
+
+ // ptr is the number of pointers followed.
+ var ptr int
+
+ // Name is a slice representation of the name data.
+ name := n.Data[:0]
+
+Loop:
+ for {
+ if currOff >= len(msg) {
+ return off, errBaseLen
+ }
+ c := int(msg[currOff])
+ currOff++
+ switch c & 0xC0 {
+ case 0x00: // String segment
+ if c == 0x00 {
+ // A zero length signals the end of the name.
+ break Loop
+ }
+ endOff := currOff + c
+ if endOff > len(msg) {
+ return off, errCalcLen
+ }
+ name = append(name, msg[currOff:endOff]...)
+ name = append(name, '.')
+ currOff = endOff
+ case 0xC0: // Pointer
+ if currOff >= len(msg) {
+ return off, errInvalidPtr
+ }
+ c1 := msg[currOff]
+ currOff++
+ if ptr == 0 {
+ newOff = currOff
+ }
+ // Don't follow too many pointers, maybe there's a loop.
+ if ptr++; ptr > 10 {
+ return off, errTooManyPtr
+ }
+ currOff = (c^0xC0)<<8 | int(c1)
+ default:
+ // Prefixes 0x80 and 0x40 are reserved.
+ return off, errReserved
+ }
+ }
+ if len(name) == 0 {
+ name = append(name, '.')
+ }
+ if len(name) > len(n.Data) {
+ return off, errCalcLen
+ }
+ n.Length = uint8(len(name))
+ if ptr == 0 {
+ newOff = currOff
+ }
+ return newOff, nil
+}
+
+func skipName(msg []byte, off int) (int, error) {
+ // newOff is the offset where the next record will start. Pointers lead
+ // to data that belongs to other names and thus doesn't count towards to
+ // the usage of this name.
+ newOff := off
+
+Loop:
+ for {
+ if newOff >= len(msg) {
+ return off, errBaseLen
+ }
+ c := int(msg[newOff])
+ newOff++
+ switch c & 0xC0 {
+ case 0x00:
+ if c == 0x00 {
+ // A zero length signals the end of the name.
+ break Loop
+ }
+ // literal string
+ newOff += c
+ if newOff > len(msg) {
+ return off, errCalcLen
+ }
+ case 0xC0:
+ // Pointer to somewhere else in msg.
+
+ // Pointers are two bytes.
+ newOff++
+
+ // Don't follow the pointer as the data here has ended.
+ break Loop
+ default:
+ // Prefixes 0x80 and 0x40 are reserved.
+ return off, errReserved
+ }
+ }
+
+ return newOff, nil
+}
+
+// A Question is a DNS query.
+type Question struct {
+ Name Name
+ Type Type
+ Class Class
+}
+
+// pack appends the wire format of the Question to msg.
+func (q *Question) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ msg, err := q.Name.pack(msg, compression, compressionOff)
+ if err != nil {
+ return msg, &nestedError{"Name", err}
+ }
+ msg = packType(msg, q.Type)
+ return packClass(msg, q.Class), nil
+}
+
+func unpackResourceBody(msg []byte, off int, hdr ResourceHeader) (ResourceBody, int, error) {
+ var (
+ r ResourceBody
+ err error
+ name string
+ )
+ switch hdr.Type {
+ case TypeA:
+ var rb AResource
+ rb, err = unpackAResource(msg, off)
+ r = &rb
+ name = "A"
+ case TypeNS:
+ var rb NSResource
+ rb, err = unpackNSResource(msg, off)
+ r = &rb
+ name = "NS"
+ case TypeCNAME:
+ var rb CNAMEResource
+ rb, err = unpackCNAMEResource(msg, off)
+ r = &rb
+ name = "CNAME"
+ case TypeSOA:
+ var rb SOAResource
+ rb, err = unpackSOAResource(msg, off)
+ r = &rb
+ name = "SOA"
+ case TypePTR:
+ var rb PTRResource
+ rb, err = unpackPTRResource(msg, off)
+ r = &rb
+ name = "PTR"
+ case TypeMX:
+ var rb MXResource
+ rb, err = unpackMXResource(msg, off)
+ r = &rb
+ name = "MX"
+ case TypeTXT:
+ var rb TXTResource
+ rb, err = unpackTXTResource(msg, off, hdr.Length)
+ r = &rb
+ name = "TXT"
+ case TypeAAAA:
+ var rb AAAAResource
+ rb, err = unpackAAAAResource(msg, off)
+ r = &rb
+ name = "AAAA"
+ case TypeSRV:
+ var rb SRVResource
+ rb, err = unpackSRVResource(msg, off)
+ r = &rb
+ name = "SRV"
+ }
+ if err != nil {
+ return nil, off, &nestedError{name + " record", err}
+ }
+ if r == nil {
+ return nil, off, errors.New("invalid resource type: " + string(hdr.Type+'0'))
+ }
+ return r, off + int(hdr.Length), nil
+}
+
+// A CNAMEResource is a CNAME Resource record.
+type CNAMEResource struct {
+ CNAME Name
+}
+
+func (r *CNAMEResource) realType() Type {
+ return TypeCNAME
+}
+
+// pack appends the wire format of the CNAMEResource to msg.
+func (r *CNAMEResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ return r.CNAME.pack(msg, compression, compressionOff)
+}
+
+func unpackCNAMEResource(msg []byte, off int) (CNAMEResource, error) {
+ var cname Name
+ if _, err := cname.unpack(msg, off); err != nil {
+ return CNAMEResource{}, err
+ }
+ return CNAMEResource{cname}, nil
+}
+
+// An MXResource is an MX Resource record.
+type MXResource struct {
+ Pref uint16
+ MX Name
+}
+
+func (r *MXResource) realType() Type {
+ return TypeMX
+}
+
+// pack appends the wire format of the MXResource to msg.
+func (r *MXResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ oldMsg := msg
+ msg = packUint16(msg, r.Pref)
+ msg, err := r.MX.pack(msg, compression, compressionOff)
+ if err != nil {
+ return oldMsg, &nestedError{"MXResource.MX", err}
+ }
+ return msg, nil
+}
+
+func unpackMXResource(msg []byte, off int) (MXResource, error) {
+ pref, off, err := unpackUint16(msg, off)
+ if err != nil {
+ return MXResource{}, &nestedError{"Pref", err}
+ }
+ var mx Name
+ if _, err := mx.unpack(msg, off); err != nil {
+ return MXResource{}, &nestedError{"MX", err}
+ }
+ return MXResource{pref, mx}, nil
+}
+
+// An NSResource is an NS Resource record.
+type NSResource struct {
+ NS Name
+}
+
+func (r *NSResource) realType() Type {
+ return TypeNS
+}
+
+// pack appends the wire format of the NSResource to msg.
+func (r *NSResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ return r.NS.pack(msg, compression, compressionOff)
+}
+
+func unpackNSResource(msg []byte, off int) (NSResource, error) {
+ var ns Name
+ if _, err := ns.unpack(msg, off); err != nil {
+ return NSResource{}, err
+ }
+ return NSResource{ns}, nil
+}
+
+// A PTRResource is a PTR Resource record.
+type PTRResource struct {
+ PTR Name
+}
+
+func (r *PTRResource) realType() Type {
+ return TypePTR
+}
+
+// pack appends the wire format of the PTRResource to msg.
+func (r *PTRResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ return r.PTR.pack(msg, compression, compressionOff)
+}
+
+func unpackPTRResource(msg []byte, off int) (PTRResource, error) {
+ var ptr Name
+ if _, err := ptr.unpack(msg, off); err != nil {
+ return PTRResource{}, err
+ }
+ return PTRResource{ptr}, nil
+}
+
+// An SOAResource is an SOA Resource record.
+type SOAResource struct {
+ NS Name
+ MBox Name
+ Serial uint32
+ Refresh uint32
+ Retry uint32
+ Expire uint32
+
+ // MinTTL the is the default TTL of Resources records which did not
+ // contain a TTL value and the TTL of negative responses. (RFC 2308
+ // Section 4)
+ MinTTL uint32
+}
+
+func (r *SOAResource) realType() Type {
+ return TypeSOA
+}
+
+// pack appends the wire format of the SOAResource to msg.
+func (r *SOAResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ oldMsg := msg
+ msg, err := r.NS.pack(msg, compression, compressionOff)
+ if err != nil {
+ return oldMsg, &nestedError{"SOAResource.NS", err}
+ }
+ msg, err = r.MBox.pack(msg, compression, compressionOff)
+ if err != nil {
+ return oldMsg, &nestedError{"SOAResource.MBox", err}
+ }
+ msg = packUint32(msg, r.Serial)
+ msg = packUint32(msg, r.Refresh)
+ msg = packUint32(msg, r.Retry)
+ msg = packUint32(msg, r.Expire)
+ return packUint32(msg, r.MinTTL), nil
+}
+
+func unpackSOAResource(msg []byte, off int) (SOAResource, error) {
+ var ns Name
+ off, err := ns.unpack(msg, off)
+ if err != nil {
+ return SOAResource{}, &nestedError{"NS", err}
+ }
+ var mbox Name
+ if off, err = mbox.unpack(msg, off); err != nil {
+ return SOAResource{}, &nestedError{"MBox", err}
+ }
+ serial, off, err := unpackUint32(msg, off)
+ if err != nil {
+ return SOAResource{}, &nestedError{"Serial", err}
+ }
+ refresh, off, err := unpackUint32(msg, off)
+ if err != nil {
+ return SOAResource{}, &nestedError{"Refresh", err}
+ }
+ retry, off, err := unpackUint32(msg, off)
+ if err != nil {
+ return SOAResource{}, &nestedError{"Retry", err}
+ }
+ expire, off, err := unpackUint32(msg, off)
+ if err != nil {
+ return SOAResource{}, &nestedError{"Expire", err}
+ }
+ minTTL, _, err := unpackUint32(msg, off)
+ if err != nil {
+ return SOAResource{}, &nestedError{"MinTTL", err}
+ }
+ return SOAResource{ns, mbox, serial, refresh, retry, expire, minTTL}, nil
+}
+
+// A TXTResource is a TXT Resource record.
+type TXTResource struct {
+ TXT []string
+}
+
+func (r *TXTResource) realType() Type {
+ return TypeTXT
+}
+
+// pack appends the wire format of the TXTResource to msg.
+func (r *TXTResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ oldMsg := msg
+ for _, s := range r.TXT {
+ var err error
+ msg, err = packText(msg, s)
+ if err != nil {
+ return oldMsg, err
+ }
+ }
+ return msg, nil
+}
+
+func unpackTXTResource(msg []byte, off int, length uint16) (TXTResource, error) {
+ txts := make([]string, 0, 1)
+ for n := uint16(0); n < length; {
+ var t string
+ var err error
+ if t, off, err = unpackText(msg, off); err != nil {
+ return TXTResource{}, &nestedError{"text", err}
+ }
+ // Check if we got too many bytes.
+ if length-n < uint16(len(t))+1 {
+ return TXTResource{}, errCalcLen
+ }
+ n += uint16(len(t)) + 1
+ txts = append(txts, t)
+ }
+ return TXTResource{txts}, nil
+}
+
+// An SRVResource is an SRV Resource record.
+type SRVResource struct {
+ Priority uint16
+ Weight uint16
+ Port uint16
+ Target Name // Not compressed as per RFC 2782.
+}
+
+func (r *SRVResource) realType() Type {
+ return TypeSRV
+}
+
+// pack appends the wire format of the SRVResource to msg.
+func (r *SRVResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ oldMsg := msg
+ msg = packUint16(msg, r.Priority)
+ msg = packUint16(msg, r.Weight)
+ msg = packUint16(msg, r.Port)
+ msg, err := r.Target.pack(msg, nil, compressionOff)
+ if err != nil {
+ return oldMsg, &nestedError{"SRVResource.Target", err}
+ }
+ return msg, nil
+}
+
+func unpackSRVResource(msg []byte, off int) (SRVResource, error) {
+ priority, off, err := unpackUint16(msg, off)
+ if err != nil {
+ return SRVResource{}, &nestedError{"Priority", err}
+ }
+ weight, off, err := unpackUint16(msg, off)
+ if err != nil {
+ return SRVResource{}, &nestedError{"Weight", err}
+ }
+ port, off, err := unpackUint16(msg, off)
+ if err != nil {
+ return SRVResource{}, &nestedError{"Port", err}
+ }
+ var target Name
+ if _, err := target.unpack(msg, off); err != nil {
+ return SRVResource{}, &nestedError{"Target", err}
+ }
+ return SRVResource{priority, weight, port, target}, nil
+}
+
+// An AResource is an A Resource record.
+type AResource struct {
+ A [4]byte
+}
+
+func (r *AResource) realType() Type {
+ return TypeA
+}
+
+// pack appends the wire format of the AResource to msg.
+func (r *AResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ return packBytes(msg, r.A[:]), nil
+}
+
+func unpackAResource(msg []byte, off int) (AResource, error) {
+ var a [4]byte
+ if _, err := unpackBytes(msg, off, a[:]); err != nil {
+ return AResource{}, err
+ }
+ return AResource{a}, nil
+}
+
+// An AAAAResource is an AAAA Resource record.
+type AAAAResource struct {
+ AAAA [16]byte
+}
+
+func (r *AAAAResource) realType() Type {
+ return TypeAAAA
+}
+
+// pack appends the wire format of the AAAAResource to msg.
+func (r *AAAAResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) {
+ return packBytes(msg, r.AAAA[:]), nil
+}
+
+func unpackAAAAResource(msg []byte, off int) (AAAAResource, error) {
+ var aaaa [16]byte
+ if _, err := unpackBytes(msg, off, aaaa[:]); err != nil {
+ return AAAAResource{}, err
+ }
+ return AAAAResource{aaaa}, nil
+}
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dnsmessage
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func mustNewName(name string) Name {
+ n, err := NewName(name)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (m *Message) String() string {
+ s := fmt.Sprintf("Message: %#v\n", &m.Header)
+ if len(m.Questions) > 0 {
+ s += "-- Questions\n"
+ for _, q := range m.Questions {
+ s += fmt.Sprintf("%#v\n", q)
+ }
+ }
+ if len(m.Answers) > 0 {
+ s += "-- Answers\n"
+ for _, a := range m.Answers {
+ s += fmt.Sprintf("%#v\n", a)
+ }
+ }
+ if len(m.Authorities) > 0 {
+ s += "-- Authorities\n"
+ for _, ns := range m.Authorities {
+ s += fmt.Sprintf("%#v\n", ns)
+ }
+ }
+ if len(m.Additionals) > 0 {
+ s += "-- Additionals\n"
+ for _, e := range m.Additionals {
+ s += fmt.Sprintf("%#v\n", e)
+ }
+ }
+ return s
+}
+
+func TestNameString(t *testing.T) {
+ want := "foo"
+ name := mustNewName(want)
+ if got := fmt.Sprint(name); got != want {
+ t.Errorf("got fmt.Sprint(%#v) = %s, want = %s", name, got, want)
+ }
+}
+
+func TestQuestionPackUnpack(t *testing.T) {
+ want := Question{
+ Name: mustNewName("."),
+ Type: TypeA,
+ Class: ClassINET,
+ }
+ buf, err := want.pack(make([]byte, 1, 50), map[string]int{}, 1)
+ if err != nil {
+ t.Fatal("Packing failed:", err)
+ }
+ var p Parser
+ p.msg = buf
+ p.header.questions = 1
+ p.section = sectionQuestions
+ p.off = 1
+ got, err := p.Question()
+ if err != nil {
+ t.Fatalf("Unpacking failed: %v\n%s", err, string(buf[1:]))
+ }
+ if p.off != len(buf) {
+ t.Errorf("Unpacked different amount than packed: got n = %d, want = %d", p.off, len(buf))
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Got = %+v, want = %+v", got, want)
+ }
+}
+
+func TestName(t *testing.T) {
+ tests := []string{
+ "",
+ ".",
+ "google..com",
+ "google.com",
+ "google..com.",
+ "google.com.",
+ ".google.com.",
+ "www..google.com.",
+ "www.google.com.",
+ }
+
+ for _, test := range tests {
+ n, err := NewName(test)
+ if err != nil {
+ t.Errorf("Creating name for %q: %v", test, err)
+ continue
+ }
+ if ns := n.String(); ns != test {
+ t.Errorf("Got %#v.String() = %q, want = %q", n, ns, test)
+ continue
+ }
+ }
+}
+
+func TestNamePackUnpack(t *testing.T) {
+ tests := []struct {
+ in string
+ want string
+ err error
+ }{
+ {"", "", errNonCanonicalName},
+ {".", ".", nil},
+ {"google..com", "", errNonCanonicalName},
+ {"google.com", "", errNonCanonicalName},
+ {"google..com.", "", errZeroSegLen},
+ {"google.com.", "google.com.", nil},
+ {".google.com.", "", errZeroSegLen},
+ {"www..google.com.", "", errZeroSegLen},
+ {"www.google.com.", "www.google.com.", nil},
+ }
+
+ for _, test := range tests {
+ in := mustNewName(test.in)
+ want := mustNewName(test.want)
+ buf, err := in.pack(make([]byte, 0, 30), map[string]int{}, 0)
+ if err != test.err {
+ t.Errorf("Packing of %q: got err = %v, want err = %v", test.in, err, test.err)
+ continue
+ }
+ if test.err != nil {
+ continue
+ }
+ var got Name
+ n, err := got.unpack(buf, 0)
+ if err != nil {
+ t.Errorf("Unpacking for %q failed: %v", test.in, err)
+ continue
+ }
+ if n != len(buf) {
+ t.Errorf(
+ "Unpacked different amount than packed for %q: got n = %d, want = %d",
+ test.in,
+ n,
+ len(buf),
+ )
+ }
+ if got != want {
+ t.Errorf("Unpacking packing of %q: got = %#v, want = %#v", test.in, got, want)
+ }
+ }
+}
+
+func checkErrorPrefix(err error, prefix string) bool {
+ e, ok := err.(*nestedError)
+ return ok && e.s == prefix
+}
+
+func TestHeaderUnpackError(t *testing.T) {
+ wants := []string{
+ "id",
+ "bits",
+ "questions",
+ "answers",
+ "authorities",
+ "additionals",
+ }
+ var buf []byte
+ var h header
+ for _, want := range wants {
+ n, err := h.unpack(buf, 0)
+ if n != 0 || !checkErrorPrefix(err, want) {
+ t.Errorf("got h.unpack([%d]byte, 0) = %d, %v, want = 0, %s", len(buf), n, err, want)
+ }
+ buf = append(buf, 0, 0)
+ }
+}
+
+func TestParserStart(t *testing.T) {
+ const want = "unpacking header"
+ var p Parser
+ for i := 0; i <= 1; i++ {
+ _, err := p.Start([]byte{})
+ if !checkErrorPrefix(err, want) {
+ t.Errorf("got p.Start(nil) = _, %v, want = _, %s", err, want)
+ }
+ }
+}
+
+func TestResourceNotStarted(t *testing.T) {
+ tests := []struct {
+ name string
+ fn func(*Parser) error
+ }{
+ {"CNAMEResource", func(p *Parser) error { _, err := p.CNAMEResource(); return err }},
+ {"MXResource", func(p *Parser) error { _, err := p.MXResource(); return err }},
+ {"NSResource", func(p *Parser) error { _, err := p.NSResource(); return err }},
+ {"PTRResource", func(p *Parser) error { _, err := p.PTRResource(); return err }},
+ {"SOAResource", func(p *Parser) error { _, err := p.SOAResource(); return err }},
+ {"TXTResource", func(p *Parser) error { _, err := p.TXTResource(); return err }},
+ {"SRVResource", func(p *Parser) error { _, err := p.SRVResource(); return err }},
+ {"AResource", func(p *Parser) error { _, err := p.AResource(); return err }},
+ {"AAAAResource", func(p *Parser) error { _, err := p.AAAAResource(); return err }},
+ }
+
+ for _, test := range tests {
+ if err := test.fn(&Parser{}); err != ErrNotStarted {
+ t.Errorf("got _, %v = p.%s(), want = _, %v", err, test.name, ErrNotStarted)
+ }
+ }
+}
+
+func TestDNSPackUnpack(t *testing.T) {
+ wants := []Message{
+ {
+ Questions: []Question{
+ {
+ Name: mustNewName("."),
+ Type: TypeAAAA,
+ Class: ClassINET,
+ },
+ },
+ Answers: []Resource{},
+ Authorities: []Resource{},
+ Additionals: []Resource{},
+ },
+ largeTestMsg(),
+ }
+ for i, want := range wants {
+ b, err := want.Pack()
+ if err != nil {
+ t.Fatalf("%d: packing failed: %v", i, err)
+ }
+ var got Message
+ err = got.Unpack(b)
+ if err != nil {
+ t.Fatalf("%d: unpacking failed: %v", i, err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("%d: got = %+v, want = %+v", i, &got, &want)
+ }
+ }
+}
+
+func TestDNSAppendPackUnpack(t *testing.T) {
+ wants := []Message{
+ {
+ Questions: []Question{
+ {
+ Name: mustNewName("."),
+ Type: TypeAAAA,
+ Class: ClassINET,
+ },
+ },
+ Answers: []Resource{},
+ Authorities: []Resource{},
+ Additionals: []Resource{},
+ },
+ largeTestMsg(),
+ }
+ for i, want := range wants {
+ b := make([]byte, 2, 514)
+ b, err := want.AppendPack(b)
+ if err != nil {
+ t.Fatalf("%d: packing failed: %v", i, err)
+ }
+ b = b[2:]
+ var got Message
+ err = got.Unpack(b)
+ if err != nil {
+ t.Fatalf("%d: unpacking failed: %v", i, err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("%d: got = %+v, want = %+v", i, &got, &want)
+ }
+ }
+}
+
+func TestSkipAll(t *testing.T) {
+ msg := largeTestMsg()
+ buf, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Packing large test message:", err)
+ }
+ var p Parser
+ if _, err := p.Start(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ name string
+ f func() error
+ }{
+ {"SkipAllQuestions", p.SkipAllQuestions},
+ {"SkipAllAnswers", p.SkipAllAnswers},
+ {"SkipAllAuthorities", p.SkipAllAuthorities},
+ {"SkipAllAdditionals", p.SkipAllAdditionals},
+ }
+ for _, test := range tests {
+ for i := 1; i <= 3; i++ {
+ if err := test.f(); err != nil {
+ t.Errorf("Call #%d to %s(): %v", i, test.name, err)
+ }
+ }
+ }
+}
+
+func TestSkipEach(t *testing.T) {
+ msg := smallTestMsg()
+
+ buf, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Packing test message:", err)
+ }
+ var p Parser
+ if _, err := p.Start(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ name string
+ f func() error
+ }{
+ {"SkipQuestion", p.SkipQuestion},
+ {"SkipAnswer", p.SkipAnswer},
+ {"SkipAuthority", p.SkipAuthority},
+ {"SkipAdditional", p.SkipAdditional},
+ }
+ for _, test := range tests {
+ if err := test.f(); err != nil {
+ t.Errorf("First call: got %s() = %v, want = %v", test.name, err, nil)
+ }
+ if err := test.f(); err != ErrSectionDone {
+ t.Errorf("Second call: got %s() = %v, want = %v", test.name, err, ErrSectionDone)
+ }
+ }
+}
+
+func TestSkipAfterRead(t *testing.T) {
+ msg := smallTestMsg()
+
+ buf, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Packing test message:", err)
+ }
+ var p Parser
+ if _, err := p.Start(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ name string
+ skip func() error
+ read func() error
+ }{
+ {"Question", p.SkipQuestion, func() error { _, err := p.Question(); return err }},
+ {"Answer", p.SkipAnswer, func() error { _, err := p.Answer(); return err }},
+ {"Authority", p.SkipAuthority, func() error { _, err := p.Authority(); return err }},
+ {"Additional", p.SkipAdditional, func() error { _, err := p.Additional(); return err }},
+ }
+ for _, test := range tests {
+ if err := test.read(); err != nil {
+ t.Errorf("Got %s() = _, %v, want = _, %v", test.name, err, nil)
+ }
+ if err := test.skip(); err != ErrSectionDone {
+ t.Errorf("Got Skip%s() = %v, want = %v", test.name, err, ErrSectionDone)
+ }
+ }
+}
+
+func TestSkipNotStarted(t *testing.T) {
+ var p Parser
+
+ tests := []struct {
+ name string
+ f func() error
+ }{
+ {"SkipAllQuestions", p.SkipAllQuestions},
+ {"SkipAllAnswers", p.SkipAllAnswers},
+ {"SkipAllAuthorities", p.SkipAllAuthorities},
+ {"SkipAllAdditionals", p.SkipAllAdditionals},
+ }
+ for _, test := range tests {
+ if err := test.f(); err != ErrNotStarted {
+ t.Errorf("Got %s() = %v, want = %v", test.name, err, ErrNotStarted)
+ }
+ }
+}
+
+func TestTooManyRecords(t *testing.T) {
+ const recs = int(^uint16(0)) + 1
+ tests := []struct {
+ name string
+ msg Message
+ want error
+ }{
+ {
+ "Questions",
+ Message{
+ Questions: make([]Question, recs),
+ },
+ errTooManyQuestions,
+ },
+ {
+ "Answers",
+ Message{
+ Answers: make([]Resource, recs),
+ },
+ errTooManyAnswers,
+ },
+ {
+ "Authorities",
+ Message{
+ Authorities: make([]Resource, recs),
+ },
+ errTooManyAuthorities,
+ },
+ {
+ "Additionals",
+ Message{
+ Additionals: make([]Resource, recs),
+ },
+ errTooManyAdditionals,
+ },
+ }
+
+ for _, test := range tests {
+ if _, got := test.msg.Pack(); got != test.want {
+ t.Errorf("Packing %d %s: got = %v, want = %v", recs, test.name, got, test.want)
+ }
+ }
+}
+
+func TestVeryLongTxt(t *testing.T) {
+ want := Resource{
+ ResourceHeader{
+ Name: mustNewName("foo.bar.example.com."),
+ Type: TypeTXT,
+ Class: ClassINET,
+ },
+ &TXTResource{[]string{
+ "",
+ "",
+ "foo bar",
+ "",
+ "www.example.com",
+ "www.example.com.",
+ strings.Repeat(".", 255),
+ }},
+ }
+ buf, err := want.pack(make([]byte, 0, 8000), map[string]int{}, 0)
+ if err != nil {
+ t.Fatal("Packing failed:", err)
+ }
+ var got Resource
+ off, err := got.Header.unpack(buf, 0)
+ if err != nil {
+ t.Fatal("Unpacking ResourceHeader failed:", err)
+ }
+ body, n, err := unpackResourceBody(buf, off, got.Header)
+ if err != nil {
+ t.Fatal("Unpacking failed:", err)
+ }
+ got.Body = body
+ if n != len(buf) {
+ t.Errorf("Unpacked different amount than packed: got n = %d, want = %d", n, len(buf))
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Got = %#v, want = %#v", got, want)
+ }
+}
+
+func TestTooLongTxt(t *testing.T) {
+ rb := TXTResource{[]string{strings.Repeat(".", 256)}}
+ if _, err := rb.pack(make([]byte, 0, 8000), map[string]int{}, 0); err != errStringTooLong {
+ t.Errorf("Packing TXTRecord with 256 character string: got err = %v, want = %v", err, errStringTooLong)
+ }
+}
+
+func TestStartAppends(t *testing.T) {
+ buf := make([]byte, 2, 514)
+ wantBuf := []byte{4, 44}
+ copy(buf, wantBuf)
+
+ b := NewBuilder(buf, Header{})
+ b.EnableCompression()
+
+ buf, err := b.Finish()
+ if err != nil {
+ t.Fatal("Building failed:", err)
+ }
+ if got, want := len(buf), headerLen+2; got != want {
+ t.Errorf("Got len(buf} = %d, want = %d", got, want)
+ }
+ if string(buf[:2]) != string(wantBuf) {
+ t.Errorf("Original data not preserved, got = %v, want = %v", buf[:2], wantBuf)
+ }
+}
+
+func TestStartError(t *testing.T) {
+ tests := []struct {
+ name string
+ fn func(*Builder) error
+ }{
+ {"Questions", func(b *Builder) error { return b.StartQuestions() }},
+ {"Answers", func(b *Builder) error { return b.StartAnswers() }},
+ {"Authorities", func(b *Builder) error { return b.StartAuthorities() }},
+ {"Additionals", func(b *Builder) error { return b.StartAdditionals() }},
+ }
+
+ envs := []struct {
+ name string
+ fn func() *Builder
+ want error
+ }{
+ {"sectionNotStarted", func() *Builder { return &Builder{section: sectionNotStarted} }, ErrNotStarted},
+ {"sectionDone", func() *Builder { return &Builder{section: sectionDone} }, ErrSectionDone},
+ }
+
+ for _, env := range envs {
+ for _, test := range tests {
+ if got := test.fn(env.fn()); got != env.want {
+ t.Errorf("got Builder{%s}.Start%s = %v, want = %v", env.name, test.name, got, env.want)
+ }
+ }
+ }
+}
+
+func TestBuilderResourceError(t *testing.T) {
+ tests := []struct {
+ name string
+ fn func(*Builder) error
+ }{
+ {"CNAMEResource", func(b *Builder) error { return b.CNAMEResource(ResourceHeader{}, CNAMEResource{}) }},
+ {"MXResource", func(b *Builder) error { return b.MXResource(ResourceHeader{}, MXResource{}) }},
+ {"NSResource", func(b *Builder) error { return b.NSResource(ResourceHeader{}, NSResource{}) }},
+ {"PTRResource", func(b *Builder) error { return b.PTRResource(ResourceHeader{}, PTRResource{}) }},
+ {"SOAResource", func(b *Builder) error { return b.SOAResource(ResourceHeader{}, SOAResource{}) }},
+ {"TXTResource", func(b *Builder) error { return b.TXTResource(ResourceHeader{}, TXTResource{}) }},
+ {"SRVResource", func(b *Builder) error { return b.SRVResource(ResourceHeader{}, SRVResource{}) }},
+ {"AResource", func(b *Builder) error { return b.AResource(ResourceHeader{}, AResource{}) }},
+ {"AAAAResource", func(b *Builder) error { return b.AAAAResource(ResourceHeader{}, AAAAResource{}) }},
+ }
+
+ envs := []struct {
+ name string
+ fn func() *Builder
+ want error
+ }{
+ {"sectionNotStarted", func() *Builder { return &Builder{section: sectionNotStarted} }, ErrNotStarted},
+ {"sectionHeader", func() *Builder { return &Builder{section: sectionHeader} }, ErrNotStarted},
+ {"sectionQuestions", func() *Builder { return &Builder{section: sectionQuestions} }, ErrNotStarted},
+ {"sectionDone", func() *Builder { return &Builder{section: sectionDone} }, ErrSectionDone},
+ }
+
+ for _, env := range envs {
+ for _, test := range tests {
+ if got := test.fn(env.fn()); got != env.want {
+ t.Errorf("got Builder{%s}.%s = %v, want = %v", env.name, test.name, got, env.want)
+ }
+ }
+ }
+}
+
+func TestFinishError(t *testing.T) {
+ var b Builder
+ want := ErrNotStarted
+ if _, got := b.Finish(); got != want {
+ t.Errorf("got Builder{}.Finish() = %v, want = %v", got, want)
+ }
+}
+
+func TestBuilder(t *testing.T) {
+ msg := largeTestMsg()
+ want, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Packing without builder:", err)
+ }
+
+ b := NewBuilder(nil, msg.Header)
+ b.EnableCompression()
+
+ if err := b.StartQuestions(); err != nil {
+ t.Fatal("b.StartQuestions():", err)
+ }
+ for _, q := range msg.Questions {
+ if err := b.Question(q); err != nil {
+ t.Fatalf("b.Question(%#v): %v", q, err)
+ }
+ }
+
+ if err := b.StartAnswers(); err != nil {
+ t.Fatal("b.StartAnswers():", err)
+ }
+ for _, a := range msg.Answers {
+ switch a.Header.Type {
+ case TypeA:
+ if err := b.AResource(a.Header, *a.Body.(*AResource)); err != nil {
+ t.Fatalf("b.AResource(%#v): %v", a, err)
+ }
+ case TypeNS:
+ if err := b.NSResource(a.Header, *a.Body.(*NSResource)); err != nil {
+ t.Fatalf("b.NSResource(%#v): %v", a, err)
+ }
+ case TypeCNAME:
+ if err := b.CNAMEResource(a.Header, *a.Body.(*CNAMEResource)); err != nil {
+ t.Fatalf("b.CNAMEResource(%#v): %v", a, err)
+ }
+ case TypeSOA:
+ if err := b.SOAResource(a.Header, *a.Body.(*SOAResource)); err != nil {
+ t.Fatalf("b.SOAResource(%#v): %v", a, err)
+ }
+ case TypePTR:
+ if err := b.PTRResource(a.Header, *a.Body.(*PTRResource)); err != nil {
+ t.Fatalf("b.PTRResource(%#v): %v", a, err)
+ }
+ case TypeMX:
+ if err := b.MXResource(a.Header, *a.Body.(*MXResource)); err != nil {
+ t.Fatalf("b.MXResource(%#v): %v", a, err)
+ }
+ case TypeTXT:
+ if err := b.TXTResource(a.Header, *a.Body.(*TXTResource)); err != nil {
+ t.Fatalf("b.TXTResource(%#v): %v", a, err)
+ }
+ case TypeAAAA:
+ if err := b.AAAAResource(a.Header, *a.Body.(*AAAAResource)); err != nil {
+ t.Fatalf("b.AAAAResource(%#v): %v", a, err)
+ }
+ case TypeSRV:
+ if err := b.SRVResource(a.Header, *a.Body.(*SRVResource)); err != nil {
+ t.Fatalf("b.SRVResource(%#v): %v", a, err)
+ }
+ }
+ }
+
+ if err := b.StartAuthorities(); err != nil {
+ t.Fatal("b.StartAuthorities():", err)
+ }
+ for _, a := range msg.Authorities {
+ if err := b.NSResource(a.Header, *a.Body.(*NSResource)); err != nil {
+ t.Fatalf("b.NSResource(%#v): %v", a, err)
+ }
+ }
+
+ if err := b.StartAdditionals(); err != nil {
+ t.Fatal("b.StartAdditionals():", err)
+ }
+ for _, a := range msg.Additionals {
+ if err := b.TXTResource(a.Header, *a.Body.(*TXTResource)); err != nil {
+ t.Fatalf("b.TXTResource(%#v): %v", a, err)
+ }
+ }
+
+ got, err := b.Finish()
+ if err != nil {
+ t.Fatal("b.Finish():", err)
+ }
+ if !bytes.Equal(got, want) {
+ t.Fatalf("Got from Builder: %#v\nwant = %#v", got, want)
+ }
+}
+
+func TestResourcePack(t *testing.T) {
+ for _, tt := range []struct {
+ m Message
+ err error
+ }{
+ {
+ Message{
+ Questions: []Question{
+ {
+ Name: mustNewName("."),
+ Type: TypeAAAA,
+ Class: ClassINET,
+ },
+ },
+ Answers: []Resource{{ResourceHeader{}, nil}},
+ },
+ &nestedError{"packing Answer", errNilResouceBody},
+ },
+ {
+ Message{
+ Questions: []Question{
+ {
+ Name: mustNewName("."),
+ Type: TypeAAAA,
+ Class: ClassINET,
+ },
+ },
+ Authorities: []Resource{{ResourceHeader{}, (*NSResource)(nil)}},
+ },
+ &nestedError{"packing Authority",
+ &nestedError{"ResourceHeader",
+ &nestedError{"Name", errNonCanonicalName},
+ },
+ },
+ },
+ {
+ Message{
+ Questions: []Question{
+ {
+ Name: mustNewName("."),
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ },
+ Additionals: []Resource{{ResourceHeader{}, nil}},
+ },
+ &nestedError{"packing Additional", errNilResouceBody},
+ },
+ } {
+ _, err := tt.m.Pack()
+ if !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("got %v for %v; want %v", err, tt.m, tt.err)
+ }
+ }
+}
+
+func benchmarkParsingSetup() ([]byte, error) {
+ name := mustNewName("foo.bar.example.com.")
+ msg := Message{
+ Header: Header{Response: true, Authoritative: true},
+ Questions: []Question{
+ {
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ },
+ Answers: []Resource{
+ {
+ ResourceHeader{
+ Name: name,
+ Class: ClassINET,
+ },
+ &AResource{[4]byte{}},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Class: ClassINET,
+ },
+ &AAAAResource{[16]byte{}},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Class: ClassINET,
+ },
+ &CNAMEResource{name},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Class: ClassINET,
+ },
+ &NSResource{name},
+ },
+ },
+ }
+
+ buf, err := msg.Pack()
+ if err != nil {
+ return nil, fmt.Errorf("msg.Pack(): %v", err)
+ }
+ return buf, nil
+}
+
+func benchmarkParsing(tb testing.TB, buf []byte) {
+ var p Parser
+ if _, err := p.Start(buf); err != nil {
+ tb.Fatal("p.Start(buf):", err)
+ }
+
+ for {
+ _, err := p.Question()
+ if err == ErrSectionDone {
+ break
+ }
+ if err != nil {
+ tb.Fatal("p.Question():", err)
+ }
+ }
+
+ for {
+ h, err := p.AnswerHeader()
+ if err == ErrSectionDone {
+ break
+ }
+ if err != nil {
+ panic(err)
+ }
+
+ switch h.Type {
+ case TypeA:
+ if _, err := p.AResource(); err != nil {
+ tb.Fatal("p.AResource():", err)
+ }
+ case TypeAAAA:
+ if _, err := p.AAAAResource(); err != nil {
+ tb.Fatal("p.AAAAResource():", err)
+ }
+ case TypeCNAME:
+ if _, err := p.CNAMEResource(); err != nil {
+ tb.Fatal("p.CNAMEResource():", err)
+ }
+ case TypeNS:
+ if _, err := p.NSResource(); err != nil {
+ tb.Fatal("p.NSResource():", err)
+ }
+ default:
+ tb.Fatalf("unknown type: %T", h)
+ }
+ }
+}
+
+func BenchmarkParsing(b *testing.B) {
+ buf, err := benchmarkParsingSetup()
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ benchmarkParsing(b, buf)
+ }
+}
+
+func TestParsingAllocs(t *testing.T) {
+ buf, err := benchmarkParsingSetup()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if allocs := testing.AllocsPerRun(100, func() { benchmarkParsing(t, buf) }); allocs > 0.5 {
+ t.Errorf("Allocations during parsing: got = %f, want ~0", allocs)
+ }
+}
+
+func benchmarkBuildingSetup() (Name, []byte) {
+ name := mustNewName("foo.bar.example.com.")
+ buf := make([]byte, 0, packStartingCap)
+ return name, buf
+}
+
+func benchmarkBuilding(tb testing.TB, name Name, buf []byte) {
+ bld := NewBuilder(buf, Header{Response: true, Authoritative: true})
+
+ if err := bld.StartQuestions(); err != nil {
+ tb.Fatal("bld.StartQuestions():", err)
+ }
+ q := Question{
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ }
+ if err := bld.Question(q); err != nil {
+ tb.Fatalf("bld.Question(%+v): %v", q, err)
+ }
+
+ hdr := ResourceHeader{
+ Name: name,
+ Class: ClassINET,
+ }
+ if err := bld.StartAnswers(); err != nil {
+ tb.Fatal("bld.StartQuestions():", err)
+ }
+
+ ar := AResource{[4]byte{}}
+ if err := bld.AResource(hdr, ar); err != nil {
+ tb.Fatalf("bld.AResource(%+v, %+v): %v", hdr, ar, err)
+ }
+
+ aaar := AAAAResource{[16]byte{}}
+ if err := bld.AAAAResource(hdr, aaar); err != nil {
+ tb.Fatalf("bld.AAAAResource(%+v, %+v): %v", hdr, aaar, err)
+ }
+
+ cnr := CNAMEResource{name}
+ if err := bld.CNAMEResource(hdr, cnr); err != nil {
+ tb.Fatalf("bld.CNAMEResource(%+v, %+v): %v", hdr, cnr, err)
+ }
+
+ nsr := NSResource{name}
+ if err := bld.NSResource(hdr, nsr); err != nil {
+ tb.Fatalf("bld.NSResource(%+v, %+v): %v", hdr, nsr, err)
+ }
+
+ if _, err := bld.Finish(); err != nil {
+ tb.Fatal("bld.Finish():", err)
+ }
+}
+
+func BenchmarkBuilding(b *testing.B) {
+ name, buf := benchmarkBuildingSetup()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ benchmarkBuilding(b, name, buf)
+ }
+}
+
+func TestBuildingAllocs(t *testing.T) {
+ name, buf := benchmarkBuildingSetup()
+ if allocs := testing.AllocsPerRun(100, func() { benchmarkBuilding(t, name, buf) }); allocs > 0.5 {
+ t.Errorf("Allocations during building: got = %f, want ~0", allocs)
+ }
+}
+
+func smallTestMsg() Message {
+ name := mustNewName("example.com.")
+ return Message{
+ Header: Header{Response: true, Authoritative: true},
+ Questions: []Question{
+ {
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ },
+ Answers: []Resource{
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ &AResource{[4]byte{127, 0, 0, 1}},
+ },
+ },
+ Authorities: []Resource{
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ &AResource{[4]byte{127, 0, 0, 1}},
+ },
+ },
+ Additionals: []Resource{
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ &AResource{[4]byte{127, 0, 0, 1}},
+ },
+ },
+ }
+}
+
+func BenchmarkPack(b *testing.B) {
+ msg := largeTestMsg()
+
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ if _, err := msg.Pack(); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkAppendPack(b *testing.B) {
+ msg := largeTestMsg()
+ buf := make([]byte, 0, packStartingCap)
+
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ if _, err := msg.AppendPack(buf[:0]); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func largeTestMsg() Message {
+ name := mustNewName("foo.bar.example.com.")
+ return Message{
+ Header: Header{Response: true, Authoritative: true},
+ Questions: []Question{
+ {
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ },
+ Answers: []Resource{
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ &AResource{[4]byte{127, 0, 0, 1}},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ &AResource{[4]byte{127, 0, 0, 2}},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeAAAA,
+ Class: ClassINET,
+ },
+ &AAAAResource{[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeCNAME,
+ Class: ClassINET,
+ },
+ &CNAMEResource{mustNewName("alias.example.com.")},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeSOA,
+ Class: ClassINET,
+ },
+ &SOAResource{
+ NS: mustNewName("ns1.example.com."),
+ MBox: mustNewName("mb.example.com."),
+ Serial: 1,
+ Refresh: 2,
+ Retry: 3,
+ Expire: 4,
+ MinTTL: 5,
+ },
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypePTR,
+ Class: ClassINET,
+ },
+ &PTRResource{mustNewName("ptr.example.com.")},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeMX,
+ Class: ClassINET,
+ },
+ &MXResource{
+ 7,
+ mustNewName("mx.example.com."),
+ },
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeSRV,
+ Class: ClassINET,
+ },
+ &SRVResource{
+ 8,
+ 9,
+ 11,
+ mustNewName("srv.example.com."),
+ },
+ },
+ },
+ Authorities: []Resource{
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeNS,
+ Class: ClassINET,
+ },
+ &NSResource{mustNewName("ns1.example.com.")},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeNS,
+ Class: ClassINET,
+ },
+ &NSResource{mustNewName("ns2.example.com.")},
+ },
+ },
+ Additionals: []Resource{
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeTXT,
+ Class: ClassINET,
+ },
+ &TXTResource{[]string{"So Long, and Thanks for All the Fish"}},
+ },
+ {
+ ResourceHeader{
+ Name: name,
+ Type: TypeTXT,
+ Class: ClassINET,
+ },
+ &TXTResource{[]string{"Hamster Huey and the Gooey Kablooie"}},
+ },
+ },
+ }
+}