]> Cypherpunks repositories - gostls13.git/commitdiff
net: make cgo resolver work more accurately with network parameter
authorEugene Kalinin <e.v.kalinin@gmail.com>
Wed, 20 Jun 2018 22:23:37 +0000 (01:23 +0300)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 25 Oct 2018 03:14:03 +0000 (03:14 +0000)
Unlike the go resolver, the existing cgo resolver exchanges both DNS A
and AAAA RR queries unconditionally and causes unreasonable connection
setup latencies to applications using the cgo resolver.

This change adds new argument (`network`) in all functions through the
series of calls: from Resolver.internetAddrList to cgoLookupIPCNAME.

Benefit: no redundant DNS calls if certain IP version is used IPv4/IPv6
(no `AAAA` DNS requests if used tcp4, udp4, ip4 network. And vice
versa: no `A` DNS requests if used tcp6, udp6, ip6 network)

Fixes #25947

Change-Id: I39edbd726d82d6133fdada4d06cd90d401e7e669
Reviewed-on: https://go-review.googlesource.com/c/120215
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
15 files changed:
src/net/cgo_stub.go
src/net/cgo_unix.go
src/net/cgo_unix_test.go
src/net/dial_test.go
src/net/error_test.go
src/net/hook.go
src/net/http/transport_test.go
src/net/ipsock.go
src/net/lookup.go
src/net/lookup_fake.go
src/net/lookup_plan9.go
src/net/lookup_test.go
src/net/lookup_unix.go
src/net/lookup_windows.go
src/net/netgo_unix_test.go

index 51259722aec8e6ce54d0d8e2ed46457afa02b3e4..041f8af12966a4363c3c4822061f7b956e478b40 100644 (file)
@@ -24,7 +24,7 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err
        return 0, nil, false
 }
 
-func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) {
+func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error, completed bool) {
        return nil, nil, false
 }
 
index 3db867a080e1fafdcd17005ece17a6b525189bb5..b7cbcfe77a9d36f82530264cddcf24ee31f860fc 100644 (file)
@@ -49,7 +49,7 @@ type reverseLookupResult struct {
 }
 
 func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error, completed bool) {
-       addrs, err, completed := cgoLookupIP(ctx, name)
+       addrs, err, completed := cgoLookupIP(ctx, "ip", name)
        for _, addr := range addrs {
                hosts = append(hosts, addr.String())
        }
@@ -69,13 +69,11 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err
        default:
                return 0, &DNSError{Err: "unknown network", Name: network + "/" + service}, true
        }
-       if len(network) >= 4 {
-               switch network[3] {
-               case '4':
-                       hints.ai_family = C.AF_INET
-               case '6':
-                       hints.ai_family = C.AF_INET6
-               }
+       switch ipVersion(network) {
+       case '4':
+               hints.ai_family = C.AF_INET
+       case '6':
+               hints.ai_family = C.AF_INET6
        }
        if ctx.Done() == nil {
                port, err := cgoLookupServicePort(&hints, network, service)
@@ -135,13 +133,20 @@ func cgoPortLookup(result chan<- portLookupResult, hints *C.struct_addrinfo, net
        result <- portLookupResult{port, err}
 }
 
-func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error) {
+func cgoLookupIPCNAME(network, name string) (addrs []IPAddr, cname string, err error) {
        acquireThread()
        defer releaseThread()
 
        var hints C.struct_addrinfo
        hints.ai_flags = cgoAddrInfoFlags
        hints.ai_socktype = C.SOCK_STREAM
+       hints.ai_family = C.AF_UNSPEC
+       switch ipVersion(network) {
+       case '4':
+               hints.ai_family = C.AF_INET
+       case '6':
+               hints.ai_family = C.AF_INET6
+       }
 
        h := make([]byte, len(name)+1)
        copy(h, name)
@@ -197,18 +202,18 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error) {
        return addrs, cname, nil
 }
 
-func cgoIPLookup(result chan<- ipLookupResult, name string) {
-       addrs, cname, err := cgoLookupIPCNAME(name)
+func cgoIPLookup(result chan<- ipLookupResult, network, name string) {
+       addrs, cname, err := cgoLookupIPCNAME(network, name)
        result <- ipLookupResult{addrs, cname, err}
 }
 
-func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) {
+func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error, completed bool) {
        if ctx.Done() == nil {
-               addrs, _, err = cgoLookupIPCNAME(name)
+               addrs, _, err = cgoLookupIPCNAME(network, name)
                return addrs, err, true
        }
        result := make(chan ipLookupResult, 1)
-       go cgoIPLookup(result, name)
+       go cgoIPLookup(result, network, name)
        select {
        case r := <-result:
                return r.addrs, r.err, true
@@ -219,11 +224,11 @@ func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, c
 
 func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
        if ctx.Done() == nil {
-               _, cname, err = cgoLookupIPCNAME(name)
+               _, cname, err = cgoLookupIPCNAME("ip", name)
                return cname, err, true
        }
        result := make(chan ipLookupResult, 1)
-       go cgoIPLookup(result, name)
+       go cgoIPLookup(result, "ip", name)
        select {
        case r := <-result:
                return r.cname, r.err, true
index b476a6d62686ea5e15dae052871d4106de894a42..c3eab5b3b2a5e203b6f7921832498d2d547442fe 100644 (file)
@@ -15,7 +15,7 @@ import (
 func TestCgoLookupIP(t *testing.T) {
        defer dnsWaitGroup.Wait()
        ctx := context.Background()
-       _, err, ok := cgoLookupIP(ctx, "localhost")
+       _, err, ok := cgoLookupIP(ctx, "ip", "localhost")
        if !ok {
                t.Errorf("cgoLookupIP must not be a placeholder")
        }
@@ -28,7 +28,7 @@ func TestCgoLookupIPWithCancel(t *testing.T) {
        defer dnsWaitGroup.Wait()
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()
-       _, err, ok := cgoLookupIP(ctx, "localhost")
+       _, err, ok := cgoLookupIP(ctx, "ip", "localhost")
        if !ok {
                t.Errorf("cgoLookupIP must not be a placeholder")
        }
index 3a45c0d2ec9a757a0beee66f3f76e41b9eabb360..983338885dac52df8c0fbc93f804e67bf6e57f3e 100644 (file)
@@ -346,7 +346,7 @@ func TestDialParallel(t *testing.T) {
        }
 }
 
-func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+func lookupSlowFast(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
        switch host {
        case "slow6loopback4":
                // Returns a slow IPv6 address, and a local IPv4 address.
@@ -355,7 +355,7 @@ func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPA
                        {IP: ParseIP("127.0.0.1")},
                }, nil
        default:
-               return fn(ctx, host)
+               return fn(ctx, network, host)
        }
 }
 
index e09670e5ce0e2f8e126107bd3b95422877231696..2819986c0cd7cefb20a759050afb57666370a771 100644 (file)
@@ -144,7 +144,7 @@ func TestDialError(t *testing.T) {
 
        origTestHookLookupIP := testHookLookupIP
        defer func() { testHookLookupIP = origTestHookLookupIP }()
-       testHookLookupIP = func(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+       testHookLookupIP = func(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
                return nil, &DNSError{Err: "dial error test", Name: "name", Server: "server", IsTimeout: true}
        }
        sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
@@ -293,7 +293,7 @@ func TestListenError(t *testing.T) {
 
        origTestHookLookupIP := testHookLookupIP
        defer func() { testHookLookupIP = origTestHookLookupIP }()
-       testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+       testHookLookupIP = func(_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
                return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
        }
        sw.Set(socktest.FilterListen, func(so *socktest.Status) (socktest.AfterFilter, error) {
@@ -353,7 +353,7 @@ func TestListenPacketError(t *testing.T) {
 
        origTestHookLookupIP := testHookLookupIP
        defer func() { testHookLookupIP = origTestHookLookupIP }()
-       testHookLookupIP = func(_ context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+       testHookLookupIP = func(_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
                return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
        }
 
index d7316ea4383f5c359ee4e05cf25158719dce8a17..5a1156378b569b98e571e819efb1d1c7accc7e00 100644 (file)
@@ -13,10 +13,11 @@ var (
        testHookHostsPath = "/etc/hosts"
        testHookLookupIP  = func(
                ctx context.Context,
-               fn func(context.Context, string) ([]IPAddr, error),
+               fn func(context.Context, string, string) ([]IPAddr, error),
+               network string,
                host string,
        ) ([]IPAddr, error) {
-               return fn(ctx, host)
+               return fn(ctx, network, host)
        }
        testHookSetKeepAlive = func() {}
 )
index 211f8cb467d5caed88d07a5aa8e60b53e00631d8..3f9750392c61d0680fcae6a4b879a8e2de976796 100644 (file)
@@ -3825,9 +3825,9 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
        }
 
        // Install a fake DNS server.
-       ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
+       ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
                if host != "dns-is-faked.golang" {
-                       t.Errorf("unexpected DNS host lookup for %q", host)
+                       t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
                        return nil, nil
                }
                return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
@@ -4176,7 +4176,7 @@ func TestTransportMaxIdleConns(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
+       ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
                return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
        })
 
@@ -4416,9 +4416,9 @@ func testTransportIDNA(t *testing.T, h2 bool) {
        }
 
        // Install a fake DNS server.
-       ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
+       ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
                if host != punyDomain {
-                       t.Errorf("got DNS host lookup for %q; want %q", host, punyDomain)
+                       t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
                        return nil, nil
                }
                return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
index 84fa0ac0a3270e6566bcdac3c079aa8d55a6d94a..7d0684d176400d0a1c2bd23a0a8b460a75bc8ff5 100644 (file)
@@ -277,7 +277,7 @@ func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addr
        }
 
        // Try as a literal IP address, then as a DNS name.
-       ips, err := r.LookupIPAddr(ctx, host)
+       ips, err := r.lookupIPAddr(ctx, net, host)
        if err != nil {
                return nil, err
        }
index e0f21fa9a8d3fd580e6da1d37e16dd75c31b5a97..cb810dea267bd15d6e95c9b355270281af6f7c07 100644 (file)
@@ -97,6 +97,19 @@ func lookupPortMap(network, service string) (port int, error error) {
        return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
 }
 
+// ipVersion returns the provided network's IP version: '4', '6' or 0
+// if network does not end in a '4' or '6' byte.
+func ipVersion(network string) byte {
+       if network == "" {
+               return 0
+       }
+       n := network[len(network)-1]
+       if n != '4' && n != '6' {
+               n = 0
+       }
+       return n
+}
+
 // DefaultResolver is the resolver used by the package-level Lookup
 // functions and by Dialers without a specified Resolver.
 var DefaultResolver = &Resolver{}
@@ -189,6 +202,12 @@ func LookupIP(host string) ([]IP, error) {
 // LookupIPAddr looks up host using the local resolver.
 // It returns a slice of that host's IPv4 and IPv6 addresses.
 func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
+       return r.lookupIPAddr(ctx, "ip", host)
+}
+
+// lookupIPAddr looks up host using the local resolver and particular network.
+// It returns a slice of that host's IPv4 and IPv6 addresses.
+func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) {
        // Make sure that no matter what we do later, host=="" is rejected.
        // parseIP, for example, does accept empty strings.
        if host == "" {
@@ -205,7 +224,7 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
        // can be overridden by tests. This is needed by net/http, so it
        // uses a context key instead of unexported variables.
        resolverFunc := r.lookupIP
-       if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string) ([]IPAddr, error)); alt != nil {
+       if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string, string) ([]IPAddr, error)); alt != nil {
                resolverFunc = alt
        }
 
@@ -218,7 +237,7 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
        dnsWaitGroup.Add(1)
        ch, called := r.getLookupGroup().DoChan(host, func() (interface{}, error) {
                defer dnsWaitGroup.Done()
-               return testHookLookupIP(lookupGroupCtx, resolverFunc, host)
+               return testHookLookupIP(lookupGroupCtx, resolverFunc, network, host)
        })
        if !called {
                dnsWaitGroup.Done()
@@ -289,6 +308,13 @@ func LookupPort(network, service string) (port int, err error) {
 func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
        port, needsLookup := parsePort(service)
        if needsLookup {
+               switch network {
+               case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
+               case "": // a hint wildcard for Go 1.0 undocumented behavior
+                       network = "ip"
+               default:
+                       return 0, &AddrError{Err: "unknown network", Addr: network}
+               }
                port, err = r.lookupPort(ctx, network, service)
                if err != nil {
                        return 0, err
index d3d1dbc90032f743641eaad1dc56966a8c30a9ba..6c8a151bcac1c89a4eca507cdb15926909b09581 100644 (file)
@@ -19,7 +19,7 @@ func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, e
        return nil, syscall.ENOPROTOOPT
 }
 
-func (*Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+func (*Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
        return nil, syscall.ENOPROTOOPT
 }
 
index d5ae9b2fd9d98be9dd6567fda9dd680a0afe98ea..70805ddf4cd05a2db388532ea2022b86fce3318c 100644 (file)
@@ -176,7 +176,7 @@ loop:
        return
 }
 
-func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+func (r *Resolver) lookupIP(ctx context.Context, _, host string) (addrs []IPAddr, err error) {
        lits, err := r.lookupHost(ctx, host)
        if err != nil {
                return
index 5c66dfa2603ad902c3b0a608277dd817a76d974f..aeeda8f7d0eb74f7bbf8f57c6bd9e91522821003 100644 (file)
@@ -20,7 +20,7 @@ import (
        "time"
 )
 
-func lookupLocalhost(ctx context.Context, fn func(context.Context, string) ([]IPAddr, error), host string) ([]IPAddr, error) {
+func lookupLocalhost(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
        switch host {
        case "localhost":
                return []IPAddr{
@@ -28,7 +28,7 @@ func lookupLocalhost(ctx context.Context, fn func(context.Context, string) ([]IP
                        {IP: IPv6loopback},
                }, nil
        default:
-               return fn(ctx, host)
+               return fn(ctx, network, host)
        }
 }
 
@@ -1008,3 +1008,29 @@ func TestConcurrentPreferGoResolversDial(t *testing.T) {
                }
        }
 }
+
+var ipVersionTests = []struct {
+       network string
+       version byte
+}{
+       {"tcp", 0},
+       {"tcp4", '4'},
+       {"tcp6", '6'},
+       {"udp", 0},
+       {"udp4", '4'},
+       {"udp6", '6'},
+       {"ip", 0},
+       {"ip4", '4'},
+       {"ip6", '6'},
+       {"ip7", 0},
+       {"", 0},
+}
+
+func TestIPVersion(t *testing.T) {
+       for _, tt := range ipVersionTests {
+               if version := ipVersion(tt.network); version != tt.version {
+                       t.Errorf("Family for: %s. Expected: %s, Got: %s", tt.network,
+                               string(tt.version), string(version))
+               }
+       }
+}
index e8e7a9bf5a2913a6085c7806b2c26a7f9ffb3f30..bef9dcfe14634b43aeb6d76e98e7fa6bfe299dfb 100644 (file)
@@ -87,13 +87,13 @@ func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string,
        return r.goLookupHostOrder(ctx, host, order)
 }
 
-func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
        if r.preferGo() {
                return r.goLookupIP(ctx, host)
        }
        order := systemConf().hostLookupOrder(r, host)
        if order == hostLookupCgo {
-               if addrs, err, ok := cgoLookupIP(ctx, host); ok {
+               if addrs, err, ok := cgoLookupIP(ctx, network, host); ok {
                        return addrs, err
                }
                // cgo not available (or netgo); fall back to Go's DNS resolver
index f76e0af400dcc87ffeba13678f2b3deed3a4d637..8a68d18a674b3ff8f6fc611431b89e2c5ac1e8f7 100644 (file)
@@ -65,7 +65,7 @@ func lookupProtocol(ctx context.Context, name string) (int, error) {
 }
 
 func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error) {
-       ips, err := r.lookupIP(ctx, name)
+       ips, err := r.lookupIP(ctx, "ip", name)
        if err != nil {
                return nil, err
        }
@@ -76,14 +76,22 @@ func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error
        return addrs, nil
 }
 
-func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
+func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) {
        // TODO(bradfitz,brainman): use ctx more. See TODO below.
 
+       var family int32 = syscall.AF_UNSPEC
+       switch ipVersion(network) {
+       case '4':
+               family = syscall.AF_INET
+       case '6':
+               family = syscall.AF_INET6
+       }
+
        getaddr := func() ([]IPAddr, error) {
                acquireThread()
                defer releaseThread()
                hints := syscall.AddrinfoW{
-                       Family:   syscall.AF_UNSPEC,
+                       Family:   family,
                        Socktype: syscall.SOCK_STREAM,
                        Protocol: syscall.IPPROTO_IP,
                }
index f2244ea69c40ba53822ecd7b7f49e77580368050..c672d3e8ebfbedf3760a7af20ceb456f5669f998 100644 (file)
@@ -16,7 +16,7 @@ func TestGoLookupIP(t *testing.T) {
        defer dnsWaitGroup.Wait()
        host := "localhost"
        ctx := context.Background()
-       _, err, ok := cgoLookupIP(ctx, host)
+       _, err, ok := cgoLookupIP(ctx, "ip", host)
        if ok {
                t.Errorf("cgoLookupIP must be a placeholder")
        }