]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base32: reduce overflow risk when calc encode/decode len
authorchanxuehong <chanxuehong@gmail.com>
Mon, 24 Jul 2023 08:58:16 +0000 (08:58 +0000)
committerGopher Robot <gobot@golang.org>
Wed, 26 Jul 2023 11:18:30 +0000 (11:18 +0000)
Same as https://go-review.googlesource.com/c/go/+/510635, reduces risk of overflow

Change-Id: I18f5560d73af76c3e853464a89ad7e42dbbd5894
GitHub-Last-Rev: 652c8c6712886184e59a110c3fa1e6dcb643d93b
GitHub-Pull-Request: golang/go#61547
Reviewed-on: https://go-review.googlesource.com/c/go/+/512200
Auto-Submit: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
src/encoding/base32/base32.go
src/encoding/base32/base32_test.go

index 3dc37b0aa7908f39ac539a365d9ffb911a199be3..a4d515edbd7d4980b4e2cd1e247aa3ed6a323161 100644 (file)
@@ -271,7 +271,7 @@ func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
 // of an input buffer of length n.
 func (enc *Encoding) EncodedLen(n int) int {
        if enc.padChar == NoPadding {
-               return (n*8 + 4) / 5
+               return n/5*8 + (n%5*8+4)/5
        }
        return (n + 4) / 5 * 8
 }
@@ -545,8 +545,7 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
 // corresponding to n bytes of base32-encoded data.
 func (enc *Encoding) DecodedLen(n int) int {
        if enc.padChar == NoPadding {
-               return n * 5 / 8
+               return n/8*5 + n%8*5/8
        }
-
        return n / 8 * 5
 }
index 8118531b383e6559b963efccb50d954a0f6cf677..bdb9f0e61f86de98d57cb1b33e3cf1968661752d 100644 (file)
@@ -8,6 +8,8 @@ import (
        "bytes"
        "errors"
        "io"
+       "math"
+       "strconv"
        "strings"
        "testing"
 )
@@ -679,52 +681,86 @@ func TestBufferedDecodingPadding(t *testing.T) {
        }
 }
 
-func TestEncodedDecodedLen(t *testing.T) {
+func TestEncodedLen(t *testing.T) {
+       var rawStdEncoding = StdEncoding.WithPadding(NoPadding)
        type test struct {
-               in      int
-               wantEnc int
-               wantDec int
-       }
-       data := bytes.Repeat([]byte("x"), 100)
-       for _, test := range []struct {
-               name  string
-               enc   *Encoding
-               cases []test
-       }{
-               {"StdEncoding", StdEncoding, []test{
-                       {0, 0, 0},
-                       {1, 8, 5},
-                       {5, 8, 5},
-                       {6, 16, 10},
-                       {10, 16, 10},
-               }},
-               {"NoPadding", StdEncoding.WithPadding(NoPadding), []test{
-                       {0, 0, 0},
-                       {1, 2, 1},
-                       {2, 4, 2},
-                       {5, 8, 5},
-                       {6, 10, 6},
-                       {7, 12, 7},
-                       {10, 16, 10},
-                       {11, 18, 11},
-               }},
-       } {
-               t.Run(test.name, func(t *testing.T) {
-                       for _, tc := range test.cases {
-                               encLen := test.enc.EncodedLen(tc.in)
-                               decLen := test.enc.DecodedLen(encLen)
-                               enc := test.enc.EncodeToString(data[:tc.in])
-                               if len(enc) != encLen {
-                                       t.Fatalf("EncodedLen(%d) = %d but encoded to %q (%d)", tc.in, encLen, enc, len(enc))
-                               }
-                               if encLen != tc.wantEnc {
-                                       t.Fatalf("EncodedLen(%d) = %d; want %d", tc.in, encLen, tc.wantEnc)
-                               }
-                               if decLen != tc.wantDec {
-                                       t.Fatalf("DecodedLen(%d) = %d; want %d", encLen, decLen, tc.wantDec)
-                               }
-                       }
-               })
+               enc  *Encoding
+               n    int
+               want int64
+       }
+       tests := []test{
+               {StdEncoding, 0, 0},
+               {StdEncoding, 1, 8},
+               {StdEncoding, 2, 8},
+               {StdEncoding, 3, 8},
+               {StdEncoding, 4, 8},
+               {StdEncoding, 5, 8},
+               {StdEncoding, 6, 16},
+               {StdEncoding, 10, 16},
+               {StdEncoding, 11, 24},
+               {rawStdEncoding, 0, 0},
+               {rawStdEncoding, 1, 2},
+               {rawStdEncoding, 2, 4},
+               {rawStdEncoding, 3, 5},
+               {rawStdEncoding, 4, 7},
+               {rawStdEncoding, 5, 8},
+               {rawStdEncoding, 6, 10},
+               {rawStdEncoding, 7, 12},
+               {rawStdEncoding, 10, 16},
+               {rawStdEncoding, 11, 18},
+       }
+       // check overflow
+       switch strconv.IntSize {
+       case 32:
+               tests = append(tests, test{rawStdEncoding, (math.MaxInt-4)/8 + 1, 429496730})
+               tests = append(tests, test{rawStdEncoding, math.MaxInt/8*5 + 4, math.MaxInt})
+       case 64:
+               tests = append(tests, test{rawStdEncoding, (math.MaxInt-4)/8 + 1, 1844674407370955162})
+               tests = append(tests, test{rawStdEncoding, math.MaxInt/8*5 + 4, math.MaxInt})
+       }
+       for _, tt := range tests {
+               if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want {
+                       t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want)
+               }
+       }
+}
+
+func TestDecodedLen(t *testing.T) {
+       var rawStdEncoding = StdEncoding.WithPadding(NoPadding)
+       type test struct {
+               enc  *Encoding
+               n    int
+               want int64
+       }
+       tests := []test{
+               {StdEncoding, 0, 0},
+               {StdEncoding, 8, 5},
+               {StdEncoding, 16, 10},
+               {StdEncoding, 24, 15},
+               {rawStdEncoding, 0, 0},
+               {rawStdEncoding, 2, 1},
+               {rawStdEncoding, 4, 2},
+               {rawStdEncoding, 5, 3},
+               {rawStdEncoding, 7, 4},
+               {rawStdEncoding, 8, 5},
+               {rawStdEncoding, 10, 6},
+               {rawStdEncoding, 12, 7},
+               {rawStdEncoding, 16, 10},
+               {rawStdEncoding, 18, 11},
+       }
+       // check overflow
+       switch strconv.IntSize {
+       case 32:
+               tests = append(tests, test{rawStdEncoding, math.MaxInt/5 + 1, 268435456})
+               tests = append(tests, test{rawStdEncoding, math.MaxInt, 1342177279})
+       case 64:
+               tests = append(tests, test{rawStdEncoding, math.MaxInt/5 + 1, 1152921504606846976})
+               tests = append(tests, test{rawStdEncoding, math.MaxInt, 5764607523034234879})
+       }
+       for _, tt := range tests {
+               if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want {
+                       t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want)
+               }
        }
 }