]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base64: reduce the overflow risk when computing encode/decode length
authorchanxuehong <chanxuehong@gmail.com>
Thu, 20 Jul 2023 04:43:38 +0000 (04:43 +0000)
committerGopher Robot <gobot@golang.org>
Fri, 21 Jul 2023 14:44:47 +0000 (14:44 +0000)
Change-Id: I0a55cdc38ae496e2070f0b9ef317a41f82352afd
GitHub-Last-Rev: c19527a26b0778cbb4548f49e1e365102709f068
GitHub-Pull-Request: golang/go#61407
Reviewed-on: https://go-review.googlesource.com/c/go/+/510635
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
src/encoding/base64/base64.go
src/encoding/base64/base64_test.go

index 6aa8a15bdc8e21614d0ab4e25d65c27529dc7fbd..87f68970627baa477af5c01cd9e4eef750e02445 100644 (file)
@@ -278,7 +278,7 @@ func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
 // of an input buffer of length n.
 func (enc *Encoding) EncodedLen(n int) int {
        if enc.padChar == NoPadding {
-               return (n*8 + 5) / 6 // minimum # chars at 6 bits per char
+               return n/3*4 + (n%3*8+5)/6 // minimum # chars at 6 bits per char
        }
        return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
 }
@@ -623,7 +623,7 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
 func (enc *Encoding) DecodedLen(n int) int {
        if enc.padChar == NoPadding {
                // Unpadded data may end with partial block of 2-3 characters.
-               return n * 6 / 8
+               return n/4*3 + n%4*6/8
        }
        // Padded base64 should always be a multiple of 4 characters in length.
        return n / 4 * 3
index 0ad88ebb3a74186f83594395946f0427c60304f5..97aea845ae3ece612004c8ce80e0dd9f7aba7b31 100644 (file)
@@ -9,8 +9,10 @@ import (
        "errors"
        "fmt"
        "io"
+       "math"
        "reflect"
        "runtime/debug"
+       "strconv"
        "strings"
        "testing"
        "time"
@@ -262,11 +264,12 @@ func TestDecodeBounds(t *testing.T) {
 }
 
 func TestEncodedLen(t *testing.T) {
-       for _, tt := range []struct {
+       type test struct {
                enc  *Encoding
                n    int
-               want int
-       }{
+               want int64
+       }
+       tests := []test{
                {RawStdEncoding, 0, 0},
                {RawStdEncoding, 1, 2},
                {RawStdEncoding, 2, 3},
@@ -278,19 +281,30 @@ func TestEncodedLen(t *testing.T) {
                {StdEncoding, 3, 4},
                {StdEncoding, 4, 8},
                {StdEncoding, 7, 12},
-       } {
-               if got := tt.enc.EncodedLen(tt.n); got != tt.want {
+       }
+       // check overflow
+       switch strconv.IntSize {
+       case 32:
+               tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 357913942})
+               tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt})
+       case 64:
+               tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 1537228672809129302})
+               tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt})
+       }
+       for _, tt := range tests {
+               if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want {
                        t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want)
                }
        }
 }
 
 func TestDecodedLen(t *testing.T) {
-       for _, tt := range []struct {
+       type test struct {
                enc  *Encoding
                n    int
-               want int
-       }{
+               want int64
+       }
+       tests := []test{
                {RawStdEncoding, 0, 0},
                {RawStdEncoding, 2, 1},
                {RawStdEncoding, 3, 2},
@@ -299,8 +313,18 @@ func TestDecodedLen(t *testing.T) {
                {StdEncoding, 0, 0},
                {StdEncoding, 4, 3},
                {StdEncoding, 8, 6},
-       } {
-               if got := tt.enc.DecodedLen(tt.n); got != tt.want {
+       }
+       // check overflow
+       switch strconv.IntSize {
+       case 32:
+               tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 268435456})
+               tests = append(tests, test{RawStdEncoding, math.MaxInt, 1610612735})
+       case 64:
+               tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 1152921504606846976})
+               tests = append(tests, test{RawStdEncoding, math.MaxInt, 6917529027641081855})
+       }
+       for _, tt := range tests {
+               if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want {
                        t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want)
                }
        }