]> Cypherpunks repositories - gostls13.git/commitdiff
net: filter destination addresses when source address is specified
authorMikio Hara <mikioh.mikioh@gmail.com>
Tue, 15 Mar 2016 01:00:12 +0000 (10:00 +0900)
committerMikio Hara <mikioh.mikioh@gmail.com>
Wed, 16 Mar 2016 03:17:56 +0000 (03:17 +0000)
This change filters out destination addresses by address family when
source address is specified to avoid running Dial operation with wrong
addressing scopes.

Fixes #11837.

Change-Id: I10b7a1fa325add2cd8ed58f105d527700a10d342
Reviewed-on: https://go-review.googlesource.com/20586
Reviewed-by: Paul Marks <pmarks@google.com>
src/net/dial.go
src/net/dial_test.go
src/net/error_test.go
src/net/ip.go
src/net/ipsock.go
src/net/net.go

index e4e44d226361221dccd5ca5fc6c58a2486cc4467..22992d5b7a95a66f0f4f5af315fc776cb8b2a1f7 100644 (file)
@@ -5,7 +5,6 @@
 package net
 
 import (
-       "errors"
        "runtime"
        "time"
 )
@@ -140,8 +139,11 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
        return "", 0, UnknownNetworkError(net)
 }
 
-func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) {
-       afnet, _, err := parseNetwork(net)
+// resolverAddrList resolves addr using hint and returns a list of
+// addresses. The result contains at least one address when error is
+// nil.
+func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) {
+       afnet, _, err := parseNetwork(network)
        if err != nil {
                return nil, err
        }
@@ -154,9 +156,59 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error)
                if err != nil {
                        return nil, err
                }
+               if op == "dial" && hint != nil && addr.Network() != hint.Network() {
+                       return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
+               }
                return addrList{addr}, nil
        }
-       return internetAddrList(afnet, addr, deadline)
+       addrs, err := internetAddrList(afnet, addr, deadline)
+       if err != nil || op != "dial" || hint == nil {
+               return addrs, err
+       }
+       var (
+               tcp      *TCPAddr
+               udp      *UDPAddr
+               ip       *IPAddr
+               wildcard bool
+       )
+       switch hint := hint.(type) {
+       case *TCPAddr:
+               tcp = hint
+               wildcard = tcp.isWildcard()
+       case *UDPAddr:
+               udp = hint
+               wildcard = udp.isWildcard()
+       case *IPAddr:
+               ip = hint
+               wildcard = ip.isWildcard()
+       }
+       naddrs := addrs[:0]
+       for _, addr := range addrs {
+               if addr.Network() != hint.Network() {
+                       return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
+               }
+               switch addr := addr.(type) {
+               case *TCPAddr:
+                       if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
+                               continue
+                       }
+                       naddrs = append(naddrs, addr)
+               case *UDPAddr:
+                       if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
+                               continue
+                       }
+                       naddrs = append(naddrs, addr)
+               case *IPAddr:
+                       if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
+                               continue
+                       }
+                       naddrs = append(naddrs, addr)
+               }
+       }
+       if len(naddrs) == 0 {
+               return nil, errNoSuitableAddress
+       }
+       return naddrs, nil
 }
 
 // Dial connects to the address on the named network.
@@ -214,7 +266,7 @@ type dialContext struct {
 // parameters.
 func (d *Dialer) Dial(network, address string) (Conn, error) {
        finalDeadline := d.deadline(time.Now())
-       addrs, err := resolveAddrList("dial", network, address, finalDeadline)
+       addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
        }
@@ -387,9 +439,6 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e
 // dial function, because some OSes don't implement the deadline feature.
 func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) {
        la := ctx.LocalAddr
-       if la != nil && la.Network() != ra.Network() {
-               return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())}
-       }
        switch ra := ra.(type) {
        case *TCPAddr:
                la, _ := la.(*TCPAddr)
@@ -420,7 +469,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str
 // instead of just the interface with the given host address.
 // See Dial for more details about address syntax.
 func Listen(net, laddr string) (Listener, error) {
-       addrs, err := resolveAddrList("listen", net, laddr, noDeadline)
+       addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
        }
@@ -447,7 +496,7 @@ func Listen(net, laddr string) (Listener, error) {
 // instead of just the interface with the given host address.
 // See Dial for the syntax of laddr.
 func ListenPacket(net, laddr string) (PacketConn, error) {
-       addrs, err := resolveAddrList("listen", net, laddr, noDeadline)
+       addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
        if err != nil {
                return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
        }
index 5fe3e856f89f0d877c985e29589147f9ba1c901a..3335df5a93c74595a2e3f9c7ddb6932206f08b36 100644 (file)
@@ -646,41 +646,118 @@ func TestDialerPartialDeadline(t *testing.T) {
        }
 }
 
+type dialerLocalAddrTest struct {
+       network, raddr string
+       laddr          Addr
+       error
+}
+
+var dialerLocalAddrTests = []dialerLocalAddrTest{
+       {"tcp4", "127.0.0.1", nil, nil},
+       {"tcp4", "127.0.0.1", &TCPAddr{}, nil},
+       {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+       {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+       {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"}},
+       {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil},
+       {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil},
+       {"tcp4", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress},
+       {"tcp4", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}},
+       {"tcp4", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+       {"tcp6", "::1", nil, nil},
+       {"tcp6", "::1", &TCPAddr{}, nil},
+       {"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+       {"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+       {"tcp6", "::1", &TCPAddr{IP: ParseIP("::")}, nil},
+       {"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress},
+       {"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress},
+       {"tcp6", "::1", &TCPAddr{IP: IPv6loopback}, nil},
+       {"tcp6", "::1", &UDPAddr{}, &AddrError{Err: "some error"}},
+       {"tcp6", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+       {"tcp", "127.0.0.1", nil, nil},
+       {"tcp", "127.0.0.1", &TCPAddr{}, nil},
+       {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+       {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+       {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil},
+       {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil},
+       {"tcp", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress},
+       {"tcp", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}},
+       {"tcp", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+       {"tcp", "::1", nil, nil},
+       {"tcp", "::1", &TCPAddr{}, nil},
+       {"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+       {"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+       {"tcp", "::1", &TCPAddr{IP: ParseIP("::")}, nil},
+       {"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress},
+       {"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress},
+       {"tcp", "::1", &TCPAddr{IP: IPv6loopback}, nil},
+       {"tcp", "::1", &UDPAddr{}, &AddrError{Err: "some error"}},
+       {"tcp", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
+}
+
 func TestDialerLocalAddr(t *testing.T) {
-       ch := make(chan error, 1)
-       handler := func(ls *localServer, ln Listener) {
-               c, err := ln.Accept()
-               if err != nil {
-                       ch <- err
-                       return
-               }
-               defer c.Close()
-               ch <- nil
-       }
-       ls, err := newLocalServer("tcp")
-       if err != nil {
-               t.Fatal(err)
+       if !supportsIPv4 || !supportsIPv6 {
+               t.Skip("both IPv4 and IPv6 are required")
        }
-       defer ls.teardown()
-       if err := ls.buildup(handler); err != nil {
-               t.Fatal(err)
+
+       if supportsIPv4map {
+               dialerLocalAddrTests = append(dialerLocalAddrTests, dialerLocalAddrTest{
+                       "tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, nil,
+               })
+       } else {
+               dialerLocalAddrTests = append(dialerLocalAddrTests, dialerLocalAddrTest{
+                       "tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"},
+               })
        }
 
-       laddr, err := ResolveTCPAddr(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
-       if err != nil {
-               t.Fatal(err)
+       origTestHookLookupIP := testHookLookupIP
+       defer func() { testHookLookupIP = origTestHookLookupIP }()
+       testHookLookupIP = lookupLocalhost
+       handler := func(ls *localServer, ln Listener) {
+               for {
+                       c, err := ln.Accept()
+                       if err != nil {
+                               return
+                       }
+                       c.Close()
+               }
        }
-       laddr.Port = 0
-       d := &Dialer{LocalAddr: laddr}
-       c, err := d.Dial(ls.Listener.Addr().Network(), ls.Addr().String())
-       if err != nil {
-               t.Fatal(err)
+       var err error
+       var lss [2]*localServer
+       for i, network := range []string{"tcp4", "tcp6"} {
+               lss[i], err = newLocalServer(network)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer lss[i].teardown()
+               if err := lss[i].buildup(handler); err != nil {
+                       t.Fatal(err)
+               }
        }
-       defer c.Close()
-       c.Read(make([]byte, 1))
-       err = <-ch
-       if err != nil {
-               t.Error(err)
+
+       for _, tt := range dialerLocalAddrTests {
+               d := &Dialer{LocalAddr: tt.laddr}
+               var addr string
+               ip := ParseIP(tt.raddr)
+               if ip.To4() != nil {
+                       addr = lss[0].Listener.Addr().String()
+               }
+               if ip.To16() != nil && ip.To4() == nil {
+                       addr = lss[1].Listener.Addr().String()
+               }
+               c, err := d.Dial(tt.network, addr)
+               if err == nil && tt.error != nil || err != nil && tt.error == nil {
+                       t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error)
+               }
+               if err != nil {
+                       if perr := parseDialError(err); perr != nil {
+                               t.Error(perr)
+                       }
+                       continue
+               }
+               c.Close()
        }
 }
 
index ee0979c74891e9a252d09266cfff0088cb52eb39..c3a4d32382ade9366d998dcc713012fde7dd9035 100644 (file)
@@ -96,7 +96,7 @@ second:
                goto third
        }
        switch nestedErr {
-       case errCanceled, errClosing, errMissingAddress:
+       case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress:
                return nil
        }
        return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
@@ -416,7 +416,7 @@ second:
                goto third
        }
        switch nestedErr {
-       case errCanceled, errClosing, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF:
+       case errCanceled, errClosing, errMissingAddress, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF:
                return nil
        }
        return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
index a25729cfc9e94b8b51269dd48420194f37e4b34c..0501f5a6a324e5f83d8781bc9f8969cd499849a1 100644 (file)
@@ -377,6 +377,10 @@ func bytesEqual(x, y []byte) bool {
        return true
 }
 
+func (ip IP) matchAddrFamily(x IP) bool {
+       return ip.To4() != nil && x.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.To16() != nil && x.To4() == nil
+}
+
 // If mask is a sequence of 1 bits followed by 0 bits,
 // return the number of 1 bits.
 func simpleMaskLength(mask IPMask) int {
index f3ac00df052dfdb24a7c934c304e3bf9199617c6..f093b4926d7ac13b374bb35091dd8a3aec01cf4b 100644 (file)
@@ -6,10 +6,7 @@
 
 package net
 
-import (
-       "errors"
-       "time"
-)
+import "time"
 
 var (
        // supportsIPv4 reports whether the platform supports IPv4
@@ -73,8 +70,6 @@ func (addrs addrList) partition(strategy func(Addr) bool) (primaries, fallbacks
        return
 }
 
-var errNoSuitableAddress = errors.New("no suitable address found")
-
 // filterAddrList applies a filter to a list of IP addresses,
 // yielding a list of Addr objects. Known filters are nil, ipv4only,
 // and ipv6only. It returns every address when the filter is nil.
index 2ff1a34981a5dd6b2e68ee6073fcfba4af202dcb..3b37b336d1bdcffe64574537b325ff4326a1373b 100644 (file)
@@ -364,6 +364,9 @@ type Error interface {
 
 // Various errors contained in OpError.
 var (
+       // For connection setup operations.
+       errNoSuitableAddress = errors.New("no suitable address found")
+
        // For connection setup and write operations.
        errMissingAddress = errors.New("missing address")