]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix IPMask.String not to crash on all-0xff mask
authorRuss Cox <rsc@golang.org>
Tue, 16 Mar 2010 21:16:33 +0000 (14:16 -0700)
committerRuss Cox <rsc@golang.org>
Tue, 16 Mar 2010 21:16:33 +0000 (14:16 -0700)
R=r
CC=golang-dev
https://golang.org/cl/438042

src/pkg/net/ip.go

index 206e5824cfc80c2c12eb9d1a3d225352cf3f4fcd..f7ccf567e0c54268dc7754b2219bf269b28c6158 100644 (file)
@@ -51,6 +51,20 @@ func IPv4(a, b, c, d byte) IP {
        return p
 }
 
+// IPv4Mask returns the IP mask (in 16-byte form) of the
+// IPv4 mask a.b.c.d.
+func IPv4Mask(a, b, c, d byte) IPMask {
+       p := make(IPMask, IPv6len)
+       for i := 0; i < 12; i++ {
+               p[i] = 0xff
+       }
+       p[12] = a
+       p[13] = b
+       p[14] = c
+       p[15] = d
+       return p
+}
+
 // Well-known IPv4 addresses
 var (
        IPv4bcast     = IPv4(255, 255, 255, 255) // broadcast
@@ -103,9 +117,9 @@ func (ip IP) To16() IP {
 
 // Default route masks for IPv4.
 var (
-       classAMask = IPMask(IPv4(0xff, 0, 0, 0))
-       classBMask = IPMask(IPv4(0xff, 0xff, 0, 0))
-       classCMask = IPMask(IPv4(0xff, 0xff, 0xff, 0))
+       classAMask = IPv4Mask(0xff, 0, 0, 0)
+       classBMask = IPv4Mask(0xff, 0xff, 0, 0)
+       classCMask = IPv4Mask(0xff, 0xff, 0xff, 0)
 )
 
 // DefaultMask returns the default IP mask for the IP address ip.
@@ -229,25 +243,28 @@ func (ip IP) String() string {
 // If mask is a sequence of 1 bits followed by 0 bits,
 // return the number of 1 bits.
 func simpleMaskLength(mask IPMask) int {
-       var i int
-       for i = 0; i < len(mask); i++ {
-               if mask[i] != 0xFF {
-                       break
+       var n int
+       for i, v := range mask {
+               if v == 0xff {
+                       n += 8
+                       continue
                }
-       }
-       n := 8 * i
-       v := mask[i]
-       for v&0x80 != 0 {
-               n++
-               v <<= 1
-       }
-       if v != 0 {
-               return -1
-       }
-       for i++; i < len(mask); i++ {
-               if mask[i] != 0 {
+               // found non-ff byte
+               // count 1 bits
+               for v&0x80 != 0 {
+                       n++
+                       v <<= 1
+               }
+               // rest must be 0 bits
+               if v != 0 {
                        return -1
                }
+               for i++; i < len(mask); i++ {
+                       if mask[i] != 0 {
+                               return -1
+                       }
+               }
+               break
        }
        return n
 }
@@ -266,8 +283,8 @@ func (mask IPMask) String() string {
                }
        case 16:
                n := simpleMaskLength(mask)
-               if n >= 0 {
-                       return itod(uint(n))
+               if n >= 12*8 {
+                       return itod(uint(n - 12*8))
                }
        }
        return IP(mask).String()