]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/gob: marshal maps using reflect.Value.MapRange
authorkorzhao <korzhao@tencent.com>
Wed, 4 Aug 2021 02:47:57 +0000 (02:47 +0000)
committerBryan C. Mills <bcmills@google.com>
Tue, 7 Sep 2021 19:14:23 +0000 (19:14 +0000)
golang.org/cl/33572 added a map iterator.

use the reflect.Value.MapRange to fix map keys that contain a NaN

Fixes #24075

Change-Id: I0214d6f26c2041797703e48eac16404f189d6982
GitHub-Last-Rev: 5c01e117f4451dbaec657d02d006905df1d0055d
GitHub-Pull-Request: golang/go#47476
Reviewed-on: https://go-review.googlesource.com/c/go/+/338609
Trust: Bryan C. Mills <bcmills@google.com>
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Rob Pike <r@golang.org>
src/encoding/gob/encode.go
src/encoding/gob/encoder_test.go

index 8f8f170c1643c14e84b91a2271599bf15c718c9a..f1f5f3862d4f523453106dba3878cea18618f2af 100644 (file)
@@ -368,11 +368,11 @@ func (enc *Encoder) encodeMap(b *encBuffer, mv reflect.Value, keyOp, elemOp encO
        state := enc.newEncoderState(b)
        state.fieldnum = -1
        state.sendZero = true
-       keys := mv.MapKeys()
-       state.encodeUint(uint64(len(keys)))
-       for _, key := range keys {
-               encodeReflectValue(state, key, keyOp, keyIndir)
-               encodeReflectValue(state, mv.MapIndex(key), elemOp, elemIndir)
+       state.encodeUint(uint64(mv.Len()))
+       mi := mv.MapRange()
+       for mi.Next() {
+               encodeReflectValue(state, mi.Key(), keyOp, keyIndir)
+               encodeReflectValue(state, mi.Value(), elemOp, elemIndir)
        }
        enc.freeEncoderState(state)
 }
index 6183646f60c8d39ddbf15896911e3d036848c474..6d50b825735f811fc4d4228fd3b87528f34cf1ed 100644 (file)
@@ -9,7 +9,9 @@ import (
        "encoding/hex"
        "fmt"
        "io"
+       "math"
        "reflect"
+       "sort"
        "strings"
        "testing"
 )
@@ -1152,3 +1154,51 @@ func TestDecodeErrorMultipleTypes(t *testing.T) {
                t.Errorf("decode: expected duplicate type error, got %s", err.Error())
        }
 }
+
+// Issue 24075
+func TestMarshalFloatMap(t *testing.T) {
+       nan1 := math.NaN()
+       nan2 := math.Float64frombits(math.Float64bits(nan1) ^ 1) // A different NaN in the same class.
+
+       in := map[float64]string{
+               nan1: "a",
+               nan1: "b",
+               nan2: "c",
+       }
+
+       var b bytes.Buffer
+       enc := NewEncoder(&b)
+       if err := enc.Encode(in); err != nil {
+               t.Errorf("Encode : %v", err)
+       }
+
+       out := map[float64]string{}
+       dec := NewDecoder(&b)
+       if err := dec.Decode(&out); err != nil {
+               t.Fatalf("Decode : %v", err)
+       }
+
+       type mapEntry struct {
+               keyBits uint64
+               value   string
+       }
+       readMap := func(m map[float64]string) (entries []mapEntry) {
+               for k, v := range m {
+                       entries = append(entries, mapEntry{math.Float64bits(k), v})
+               }
+               sort.Slice(entries, func(i, j int) bool {
+                       ei, ej := entries[i], entries[j]
+                       if ei.keyBits != ej.keyBits {
+                               return ei.keyBits < ej.keyBits
+                       }
+                       return ei.value < ej.value
+               })
+               return entries
+       }
+
+       got := readMap(out)
+       want := readMap(in)
+       if !reflect.DeepEqual(got, want) {
+               t.Fatalf("\nEncode: %v\nDecode: %v", want, got)
+       }
+}