]> Cypherpunks repositories - gostls13.git/commitdiff
net: validate network in Dial{,IP} and Listen{Packet,IP} for IP networks
authorMikio Hara <mikioh.mikioh@gmail.com>
Thu, 13 Apr 2017 09:00:35 +0000 (18:00 +0900)
committerMikio Hara <mikioh.mikioh@gmail.com>
Fri, 14 Apr 2017 08:44:45 +0000 (08:44 +0000)
The argument of the first parameter for connection setup functions on
IP networks must contain a protocol name or number. This change adds
validation for arguments of IP networks to connection setup functions.

Fixes #18185.

Change-Id: I6aaedd7806e3ed1043d4b1c834024f350b99361d
Reviewed-on: https://go-review.googlesource.com/40512
Run-TryBot: Mikio Hara <mikioh.mikioh@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/net/dial.go
src/net/iprawsock.go
src/net/iprawsock_posix.go
src/net/iprawsock_test.go

index 0a7da408fe526a567a60f208c82132bc5973d88d..bed0b1e3e0c5e982fc19f97b531a6799f36693df 100644 (file)
@@ -135,23 +135,26 @@ func (d *Dialer) fallbackDelay() time.Duration {
        }
 }
 
-func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) {
-       i := last(net, ':')
+func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
+       i := last(network, ':')
        if i < 0 { // no colon
-               switch net {
+               switch network {
                case "tcp", "tcp4", "tcp6":
                case "udp", "udp4", "udp6":
                case "ip", "ip4", "ip6":
+                       if needsProto {
+                               return "", 0, UnknownNetworkError(network)
+                       }
                case "unix", "unixgram", "unixpacket":
                default:
-                       return "", 0, UnknownNetworkError(net)
+                       return "", 0, UnknownNetworkError(network)
                }
-               return net, 0, nil
+               return network, 0, nil
        }
-       afnet = net[:i]
+       afnet = network[:i]
        switch afnet {
        case "ip", "ip4", "ip6":
-               protostr := net[i+1:]
+               protostr := network[i+1:]
                proto, i, ok := dtoi(protostr)
                if !ok || i != len(protostr) {
                        proto, err = lookupProtocol(ctx, protostr)
@@ -161,14 +164,14 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err
                }
                return afnet, proto, nil
        }
-       return "", 0, UnknownNetworkError(net)
+       return "", 0, UnknownNetworkError(network)
 }
 
 // resolveAddrList resolves addr using hint and returns a list of
 // addresses. The result contains at least one address when error is
 // nil.
 func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
-       afnet, _, err := parseNetwork(ctx, network)
+       afnet, _, err := parseNetwork(ctx, network, true)
        if err != nil {
                return nil, err
        }
index d994fc67c65825f221f4b42dc25723b56159f953..d69a303d786b7784f9209cac12e6665894e4f91b 100644 (file)
@@ -71,7 +71,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) {
        if net == "" { // a hint wildcard for Go 1.0 undocumented behavior
                net = "ip"
        }
-       afnet, _, err := parseNetwork(context.Background(), net)
+       afnet, _, err := parseNetwork(context.Background(), net, false)
        if err != nil {
                return nil, err
        }
index 8f4b702e480624c340ac96714abab84f15c0d707..5d76818af9f70121a2e2e72a1b0fbb469b17d7ad 100644 (file)
@@ -113,7 +113,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
 }
 
 func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
-       network, proto, err := parseNetwork(ctx, netProto)
+       network, proto, err := parseNetwork(ctx, netProto, true)
        if err != nil {
                return nil, err
        }
@@ -133,7 +133,7 @@ func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn
 }
 
 func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
-       network, proto, err := parseNetwork(ctx, netProto)
+       network, proto, err := parseNetwork(ctx, netProto, true)
        if err != nil {
                return nil, err
        }
index 5d33b26a91e0291b54ba232a176e6eaae6e28519..8972051f5d5e21f58c2f56fa5aaf1e2fab78f901 100644 (file)
@@ -117,3 +117,75 @@ func TestIPConnRemoteName(t *testing.T) {
                t.Fatalf("got %#v; want %#v", c.RemoteAddr(), raddr)
        }
 }
+
+func TestDialListenIPArgs(t *testing.T) {
+       type test struct {
+               argLists   [][2]string
+               shouldFail bool
+       }
+       tests := []test{
+               {
+                       argLists: [][2]string{
+                               {"ip", "127.0.0.1"},
+                               {"ip:", "127.0.0.1"},
+                               {"ip::", "127.0.0.1"},
+                               {"ip", "::1"},
+                               {"ip:", "::1"},
+                               {"ip::", "::1"},
+                               {"ip4", "127.0.0.1"},
+                               {"ip4:", "127.0.0.1"},
+                               {"ip4::", "127.0.0.1"},
+                               {"ip6", "::1"},
+                               {"ip6:", "::1"},
+                               {"ip6::", "::1"},
+                       },
+                       shouldFail: true,
+               },
+       }
+       if testableNetwork("ip") {
+               priv := test{shouldFail: false}
+               for _, tt := range []struct {
+                       network, address string
+                       args             [2]string
+               }{
+                       {"ip4:47", "127.0.0.1", [2]string{"ip4:47", "127.0.0.1"}},
+                       {"ip6:47", "::1", [2]string{"ip6:47", "::1"}},
+               } {
+                       c, err := ListenPacket(tt.network, tt.address)
+                       if err != nil {
+                               continue
+                       }
+                       c.Close()
+                       priv.argLists = append(priv.argLists, tt.args)
+               }
+               if len(priv.argLists) > 0 {
+                       tests = append(tests, priv)
+               }
+       }
+
+       for _, tt := range tests {
+               for _, args := range tt.argLists {
+                       _, err := Dial(args[0], args[1])
+                       if tt.shouldFail != (err != nil) {
+                               t.Errorf("Dial(%q, %q) = %v; want (err != nil) is %t", args[0], args[1], err, tt.shouldFail)
+                       }
+                       _, err = ListenPacket(args[0], args[1])
+                       if tt.shouldFail != (err != nil) {
+                               t.Errorf("ListenPacket(%q, %q) = %v; want (err != nil) is %t", args[0], args[1], err, tt.shouldFail)
+                       }
+                       a, err := ResolveIPAddr("ip", args[1])
+                       if err != nil {
+                               t.Errorf("ResolveIPAddr(\"ip\", %q) = %v", args[1], err)
+                               continue
+                       }
+                       _, err = DialIP(args[0], nil, a)
+                       if tt.shouldFail != (err != nil) {
+                               t.Errorf("DialIP(%q, %v) = %v; want (err != nil) is %t", args[0], a, err, tt.shouldFail)
+                       }
+                       _, err = ListenIP(args[0], a)
+                       if tt.shouldFail != (err != nil) {
+                               t.Errorf("ListenIP(%q, %v) = %v; want (err != nil) is %t", args[0], a, err, tt.shouldFail)
+                       }
+               }
+       }
+}