]> Cypherpunks repositories - gostls13.git/commitdiff
net: do not use reflect for DNS messages.
authorRémy Oudompheng <oudomphe@phare.normalesup.org>
Tue, 6 Mar 2012 07:02:39 +0000 (08:02 +0100)
committerRémy Oudompheng <oudomphe@phare.normalesup.org>
Tue, 6 Mar 2012 07:02:39 +0000 (08:02 +0100)
Fixes #3201.

R=bradfitz, bradfitz, rsc
CC=golang-dev, remy
https://golang.org/cl/5753045

src/pkg/go/build/deps_test.go
src/pkg/net/dnsmsg.go
src/pkg/net/dnsmsg_test.go

index d10bfa8f36e99daea6d0b1a813d7335d066c958f..89033e9c57bd021f3a91553d232d7e80f0035857 100644 (file)
@@ -226,8 +226,8 @@ var pkgDeps = map[string][]string{
        "os/user": {"L3", "CGO", "syscall"},
 
        // Basic networking.
-       // TODO: Remove reflect, possibly math/rand.
-       "net": {"L0", "CGO", "math/rand", "os", "reflect", "sort", "syscall", "time"},
+       // TODO: maybe remove math/rand.
+       "net": {"L0", "CGO", "math/rand", "os", "sort", "syscall", "time"},
 
        // NET enables use of basic network-related packages.
        "NET": {
index 4d1c8371ef0308775bddba2f5e3604f3f1c19baf..b6ebe117363b3a821e9078c56b68ee1d3937dd01 100644 (file)
@@ -7,11 +7,10 @@
 // This is intended to support name resolution during Dial.
 // It doesn't have to be blazing fast.
 //
-// Rather than write the usual handful of routines to pack and
-// unpack every message that can appear on the wire, we use
-// reflection to write a generic pack/unpack for structs and then
-// use it.  Thus, if in the future we need to define new message
-// structs, no new pack/unpack/printing code needs to be written.
+// Each message structure has a Walk method that is used by
+// a generic pack/unpack routine. Thus, if in the future we need
+// to define new message structs, no new pack/unpack/printing code
+// needs to be written.
 //
 // The first half of this file defines the DNS message formats.
 // The second half implements the conversion to and from wire format.
 
 package net
 
-import (
-       "reflect"
-)
-
 // Packet formats
 
 // Wire constants.
@@ -73,6 +68,20 @@ const (
        dnsRcodeRefused        = 5
 )
 
+// A dnsStruct describes how to iterate over its fields to emulate
+// reflective marshalling.
+type dnsStruct interface {
+       // Walk iterates over fields of a structure and calls f
+       // with a reference to that field, the name of the field
+       // and a tag ("", "domain", "ipv4", "ipv6") specifying
+       // particular encodings. Possible concrete types
+       // for v are *uint16, *uint32, *string, or []byte, and
+       // *int, *bool in the case of dnsMsgHdr.
+       // Whenever f returns false, Walk must stop and return
+       // false, and otherwise return true.
+       Walk(f func(v interface{}, name, tag string) (ok bool)) (ok bool)
+}
+
 // The wire format for the DNS packet header.
 type dnsHeader struct {
        Id                                 uint16
@@ -80,6 +89,15 @@ type dnsHeader struct {
        Qdcount, Ancount, Nscount, Arcount uint16
 }
 
+func (h *dnsHeader) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return f(&h.Id, "Id", "") &&
+               f(&h.Bits, "Bits", "") &&
+               f(&h.Qdcount, "Qdcount", "") &&
+               f(&h.Ancount, "Ancount", "") &&
+               f(&h.Nscount, "Nscount", "") &&
+               f(&h.Arcount, "Arcount", "")
+}
+
 const (
        // dnsHeader.Bits
        _QR = 1 << 15 // query/response (response=1)
@@ -96,6 +114,12 @@ type dnsQuestion struct {
        Qclass uint16
 }
 
+func (q *dnsQuestion) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return f(&q.Name, "Name", "domain") &&
+               f(&q.Qtype, "Qtype", "") &&
+               f(&q.Qclass, "Qclass", "")
+}
+
 // DNS responses (resource records).
 // There are many types of messages,
 // but they all share the same header.
@@ -111,7 +135,16 @@ func (h *dnsRR_Header) Header() *dnsRR_Header {
        return h
 }
 
+func (h *dnsRR_Header) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return f(&h.Name, "Name", "domain") &&
+               f(&h.Rrtype, "Rrtype", "") &&
+               f(&h.Class, "Class", "") &&
+               f(&h.Ttl, "Ttl", "") &&
+               f(&h.Rdlength, "Rdlength", "")
+}
+
 type dnsRR interface {
+       dnsStruct
        Header() *dnsRR_Header
 }
 
@@ -126,6 +159,10 @@ func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_CNAME) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Cname, "Cname", "domain")
+}
+
 type dnsRR_HINFO struct {
        Hdr dnsRR_Header
        Cpu string
@@ -136,6 +173,10 @@ func (rr *dnsRR_HINFO) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_HINFO) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Cpu, "Cpu", "") && f(&rr.Os, "Os", "")
+}
+
 type dnsRR_MB struct {
        Hdr dnsRR_Header
        Mb  string `net:"domain-name"`
@@ -145,6 +186,10 @@ func (rr *dnsRR_MB) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MB) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Mb, "Mb", "domain")
+}
+
 type dnsRR_MG struct {
        Hdr dnsRR_Header
        Mg  string `net:"domain-name"`
@@ -154,6 +199,10 @@ func (rr *dnsRR_MG) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MG) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Mg, "Mg", "domain")
+}
+
 type dnsRR_MINFO struct {
        Hdr   dnsRR_Header
        Rmail string `net:"domain-name"`
@@ -164,6 +213,10 @@ func (rr *dnsRR_MINFO) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MINFO) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Rmail, "Rmail", "domain") && f(&rr.Email, "Email", "domain")
+}
+
 type dnsRR_MR struct {
        Hdr dnsRR_Header
        Mr  string `net:"domain-name"`
@@ -173,6 +226,10 @@ func (rr *dnsRR_MR) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MR) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Mr, "Mr", "domain")
+}
+
 type dnsRR_MX struct {
        Hdr  dnsRR_Header
        Pref uint16
@@ -183,6 +240,10 @@ func (rr *dnsRR_MX) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_MX) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Pref, "Pref", "") && f(&rr.Mx, "Mx", "domain")
+}
+
 type dnsRR_NS struct {
        Hdr dnsRR_Header
        Ns  string `net:"domain-name"`
@@ -192,6 +253,10 @@ func (rr *dnsRR_NS) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_NS) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Ns, "Ns", "domain")
+}
+
 type dnsRR_PTR struct {
        Hdr dnsRR_Header
        Ptr string `net:"domain-name"`
@@ -201,6 +266,10 @@ func (rr *dnsRR_PTR) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_PTR) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Ptr, "Ptr", "domain")
+}
+
 type dnsRR_SOA struct {
        Hdr     dnsRR_Header
        Ns      string `net:"domain-name"`
@@ -216,6 +285,17 @@ func (rr *dnsRR_SOA) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_SOA) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) &&
+               f(&rr.Ns, "Ns", "domain") &&
+               f(&rr.Mbox, "Mbox", "domain") &&
+               f(&rr.Serial, "Serial", "") &&
+               f(&rr.Refresh, "Refresh", "") &&
+               f(&rr.Retry, "Retry", "") &&
+               f(&rr.Expire, "Expire", "") &&
+               f(&rr.Minttl, "Minttl", "")
+}
+
 type dnsRR_TXT struct {
        Hdr dnsRR_Header
        Txt string // not domain name
@@ -225,6 +305,10 @@ func (rr *dnsRR_TXT) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_TXT) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.Txt, "Txt", "")
+}
+
 type dnsRR_SRV struct {
        Hdr      dnsRR_Header
        Priority uint16
@@ -237,6 +321,14 @@ func (rr *dnsRR_SRV) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_SRV) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) &&
+               f(&rr.Priority, "Priority", "") &&
+               f(&rr.Weight, "Weight", "") &&
+               f(&rr.Port, "Port", "") &&
+               f(&rr.Target, "Target", "domain")
+}
+
 type dnsRR_A struct {
        Hdr dnsRR_Header
        A   uint32 `net:"ipv4"`
@@ -246,6 +338,10 @@ func (rr *dnsRR_A) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_A) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(&rr.A, "A", "ipv4")
+}
+
 type dnsRR_AAAA struct {
        Hdr  dnsRR_Header
        AAAA [16]byte `net:"ipv6"`
@@ -255,6 +351,10 @@ func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
        return &rr.Hdr
 }
 
+func (rr *dnsRR_AAAA) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return rr.Hdr.Walk(f) && f(rr.AAAA[:], "AAAA", "ipv6")
+}
+
 // Packing and unpacking.
 //
 // All the packers and unpackers take a (msg []byte, off int)
@@ -384,134 +484,107 @@ Loop:
        return s, off1, true
 }
 
-// TODO(rsc): Move into generic library?
-// Pack a reflect.StructValue into msg.  Struct members can only be uint16, uint32, string,
-// [n]byte, and other (often anonymous) structs.
-func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
-       for i := 0; i < val.NumField(); i++ {
-               f := val.Type().Field(i)
-               switch fv := val.Field(i); fv.Kind() {
+// packStruct packs a structure into msg at specified offset off, and
+// returns off1 such that msg[off:off1] is the encoded data.
+func packStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
+       ok = any.Walk(func(field interface{}, name, tag string) bool {
+               switch fv := field.(type) {
                default:
-                       println("net: dns: unknown packing type", f.Type.String())
-                       return len(msg), false
-               case reflect.Struct:
-                       off, ok = packStructValue(fv, msg, off)
-               case reflect.Uint16:
+                       println("net: dns: unknown packing type")
+                       return false
+               case *uint16:
+                       i := *fv
                        if off+2 > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       i := fv.Uint()
                        msg[off] = byte(i >> 8)
                        msg[off+1] = byte(i)
                        off += 2
-               case reflect.Uint32:
-                       if off+4 > len(msg) {
-                               return len(msg), false
-                       }
-                       i := fv.Uint()
+               case *uint32:
+                       i := *fv
                        msg[off] = byte(i >> 24)
                        msg[off+1] = byte(i >> 16)
                        msg[off+2] = byte(i >> 8)
                        msg[off+3] = byte(i)
                        off += 4
-               case reflect.Array:
-                       if fv.Type().Elem().Kind() != reflect.Uint8 {
-                               println("net: dns: unknown packing type", f.Type.String())
-                               return len(msg), false
-                       }
-                       n := fv.Len()
+               case []byte:
+                       n := len(fv)
                        if off+n > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv)
+                       copy(msg[off:off+n], fv)
                        off += n
-               case reflect.String:
-                       // There are multiple string encodings.
-                       // The tag distinguishes ordinary strings from domain names.
-                       s := fv.String()
-                       switch f.Tag {
+               case *string:
+                       s := *fv
+                       switch tag {
                        default:
-                               println("net: dns: unknown string tag", string(f.Tag))
-                               return len(msg), false
-                       case `net:"domain-name"`:
+                               println("net: dns: unknown string tag", tag)
+                               return false
+                       case "domain":
                                off, ok = packDomainName(s, msg, off)
                                if !ok {
-                                       return len(msg), false
+                                       return false
                                }
                        case "":
                                // Counted string: 1 byte length.
                                if len(s) > 255 || off+1+len(s) > len(msg) {
-                                       return len(msg), false
+                                       return false
                                }
                                msg[off] = byte(len(s))
                                off++
                                off += copy(msg[off:], s)
                        }
                }
+               return true
+       })
+       if !ok {
+               return len(msg), false
        }
        return off, true
 }
 
-func structValue(any interface{}) reflect.Value {
-       return reflect.ValueOf(any).Elem()
-}
-
-func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
-       off, ok = packStructValue(structValue(any), msg, off)
-       return off, ok
-}
-
-// TODO(rsc): Move into generic library?
-// Unpack a reflect.StructValue from msg.
-// Same restrictions as packStructValue.
-func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
-       for i := 0; i < val.NumField(); i++ {
-               f := val.Type().Field(i)
-               switch fv := val.Field(i); fv.Kind() {
+// unpackStruct decodes msg[off:] into the given structure, and
+// returns off1 such that msg[off:off1] is the encoded data.
+func unpackStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
+       ok = any.Walk(func(field interface{}, name, tag string) bool {
+               switch fv := field.(type) {
                default:
-                       println("net: dns: unknown packing type", f.Type.String())
-                       return len(msg), false
-               case reflect.Struct:
-                       off, ok = unpackStructValue(fv, msg, off)
-               case reflect.Uint16:
+                       println("net: dns: unknown packing type")
+                       return false
+               case *uint16:
                        if off+2 > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       i := uint16(msg[off])<<8 | uint16(msg[off+1])
-                       fv.SetUint(uint64(i))
+                       *fv = uint16(msg[off])<<8 | uint16(msg[off+1])
                        off += 2
-               case reflect.Uint32:
+               case *uint32:
                        if off+4 > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
-                       fv.SetUint(uint64(i))
+                       *fv = uint32(msg[off])<<24 | uint32(msg[off+1])<<16 |
+                               uint32(msg[off+2])<<8 | uint32(msg[off+3])
                        off += 4
-               case reflect.Array:
-                       if fv.Type().Elem().Kind() != reflect.Uint8 {
-                               println("net: dns: unknown packing type", f.Type.String())
-                               return len(msg), false
-                       }
-                       n := fv.Len()
+               case []byte:
+                       n := len(fv)
                        if off+n > len(msg) {
-                               return len(msg), false
+                               return false
                        }
-                       reflect.Copy(fv, reflect.ValueOf(msg[off:off+n]))
+                       copy(fv, msg[off:off+n])
                        off += n
-               case reflect.String:
+               case *string:
                        var s string
-                       switch f.Tag {
+                       switch tag {
                        default:
-                               println("net: dns: unknown string tag", string(f.Tag))
-                               return len(msg), false
-                       case `net:"domain-name"`:
+                               println("net: dns: unknown string tag", tag)
+                               return false
+                       case "domain":
                                s, off, ok = unpackDomainName(msg, off)
                                if !ok {
-                                       return len(msg), false
+                                       return false
                                }
                        case "":
                                if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
-                                       return len(msg), false
+                                       return false
                                }
                                n := int(msg[off])
                                off++
@@ -522,53 +595,77 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
                                off += n
                                s = string(b)
                        }
-                       fv.SetString(s)
+                       *fv = s
                }
+               return true
+       })
+       if !ok {
+               return len(msg), false
        }
        return off, true
 }
 
-func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
-       off, ok = unpackStructValue(structValue(any), msg, off)
-       return off, ok
-}
-
-// Generic struct printer.
-// Doesn't care about the string tag `net:"domain-name"`,
-// but does look for an `net:"ipv4"` tag on uint32 variables
-// and the `net:"ipv6"` tag on array variables,
-// printing them as IP addresses.
-func printStructValue(val reflect.Value) string {
+// Generic struct printer. Prints fields with tag "ipv4" or "ipv6"
+// as IP addresses.
+func printStruct(any dnsStruct) string {
        s := "{"
-       for i := 0; i < val.NumField(); i++ {
-               if i > 0 {
+       i := 0
+       any.Walk(func(val interface{}, name, tag string) bool {
+               i++
+               if i > 1 {
                        s += ", "
                }
-               f := val.Type().Field(i)
-               if !f.Anonymous {
-                       s += f.Name + "="
-               }
-               fval := val.Field(i)
-               if fv := fval; fv.Kind() == reflect.Struct {
-                       s += printStructValue(fv)
-               } else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == `net:"ipv4"` {
-                       i := fv.Uint()
+               s += name + "="
+               switch tag {
+               case "ipv4":
+                       i := val.(uint32)
                        s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
-               } else if fv := fval; fv.Kind() == reflect.Array && f.Tag == `net:"ipv6"` {
-                       i := fv.Interface().([]byte)
+               case "ipv6":
+                       i := val.([]byte)
                        s += IP(i).String()
-               } else {
-                       // TODO(bradfitz,rsc): this next line panics (the String method of
-                       // *dnsMsg has been broken for awhile). Rewrite, ditch reflect.
-                       //s += fmt.Sprint(fval.Interface())
+               default:
+                       var i int64
+                       switch v := val.(type) {
+                       default:
+                               // can't really happen.
+                               s += "<unknown type>"
+                               return true
+                       case *string:
+                               s += *v
+                               return true
+                       case []byte:
+                               s += string(v)
+                               return true
+                       case *bool:
+                               if *v {
+                                       s += "true"
+                               } else {
+                                       s += "false"
+                               }
+                               return true
+                       case *int:
+                               i = int64(*v)
+                       case *uint:
+                               i = int64(*v)
+                       case *uint8:
+                               i = int64(*v)
+                       case *uint16:
+                               i = int64(*v)
+                       case *uint32:
+                               i = int64(*v)
+                       case *uint64:
+                               i = int64(*v)
+                       case *uintptr:
+                               i = int64(*v)
+                       }
+                       s += itoa(int(i))
                }
-       }
+               return true
+       })
        s += "}"
        return s
 }
 
-func printStruct(any interface{}) string { return printStructValue(structValue(any)) }
-
 // Resource record packer.
 func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
        var off1 int
@@ -627,6 +724,17 @@ type dnsMsgHdr struct {
        rcode               int
 }
 
+func (h *dnsMsgHdr) Walk(f func(v interface{}, name, tag string) bool) bool {
+       return f(&h.id, "id", "") &&
+               f(&h.response, "response", "") &&
+               f(&h.opcode, "opcode", "") &&
+               f(&h.authoritative, "authoritative", "") &&
+               f(&h.truncated, "truncated", "") &&
+               f(&h.recursion_desired, "recursion_desired", "") &&
+               f(&h.recursion_available, "recursion_available", "") &&
+               f(&h.rcode, "rcode", "")
+}
+
 type dnsMsg struct {
        dnsMsgHdr
        question []dnsQuestion
index 58f53b74197bb22f7e5a37c51b5fbbf20a824eb3..c39dbdb049d63058f8e2df942de151a366304590 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "encoding/hex"
+       "reflect"
        "testing"
 )
 
@@ -39,6 +40,16 @@ func TestDNSParseSRVReply(t *testing.T) {
                t.Errorf("len(addrs) = %d; want %d", g, e)
                t.Logf("addrs = %#v", addrs)
        }
+       // repack and unpack.
+       data2, ok := msg.Pack()
+       msg2 := new(dnsMsg)
+       msg2.Unpack(data2)
+       switch {
+       case !ok:
+               t.Errorf("failed to repack message")
+       case !reflect.DeepEqual(msg, msg2):
+               t.Errorf("repacked message differs from original")
+       }
 }
 
 func TestDNSParseCorruptSRVReply(t *testing.T) {