]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/xml: make use of reflect.TypeAssert
authorapocelipes <seve3r@outlook.com>
Wed, 24 Sep 2025 03:23:03 +0000 (03:23 +0000)
committerGopher Robot <gobot@golang.org>
Mon, 29 Sep 2025 14:53:08 +0000 (07:53 -0700)
To make the code more readable and improve performance:

goos: darwin
goarch: arm64
pkg: encoding/xml
cpu: Apple M4
                 │     old     │                 new                 │
                 │   sec/op    │   sec/op     vs base                │
Marshal-10         1.902µ ± 1%   1.496µ ± 1%  -21.37% (p=0.000 n=10)
Unmarshal-10       3.877µ ± 1%   3.418µ ± 2%  -11.84% (p=0.000 n=10)
HTMLAutoClose-10   1.314µ ± 3%   1.333µ ± 1%        ~ (p=0.270 n=10)
geomean            2.132µ        1.896µ       -11.07%

                 │     old      │                  new                  │
                 │     B/op     │     B/op      vs base                 │
Marshal-10         5.570Ki ± 0%   5.570Ki ± 0%       ~ (p=1.000 n=10) ¹
Unmarshal-10       7.586Ki ± 0%   7.555Ki ± 0%  -0.41% (p=0.000 n=10)
HTMLAutoClose-10   3.496Ki ± 0%   3.496Ki ± 0%       ~ (p=1.000 n=10) ¹
geomean            5.286Ki        5.279Ki       -0.14%
¹ all samples are equal

                 │    old     │                 new                 │
                 │ allocs/op  │ allocs/op   vs base                 │
Marshal-10         23.00 ± 0%   23.00 ± 0%       ~ (p=1.000 n=10) ¹
Unmarshal-10       185.0 ± 0%   184.0 ± 0%  -0.54% (p=0.000 n=10)
HTMLAutoClose-10   93.00 ± 0%   93.00 ± 0%       ~ (p=1.000 n=10) ¹
geomean            73.42        73.28       -0.18%
¹ all samples are equal

Updates #62121

Change-Id: Ie458e7458d4358c380374571d380ca3b65ca87bb
GitHub-Last-Rev: bb6bb3039328ca1d53ee3d56fd6597109ed76b09
GitHub-Pull-Request: golang/go#75483
Reviewed-on: https://go-review.googlesource.com/c/go/+/704215
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Keith Randall <khr@golang.org>
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
Reviewed-by: Carlos Amedee <carlos@golang.org>
src/encoding/xml/marshal.go
src/encoding/xml/read.go

index 133503fa2de41c4b64a8077e9651b3facc11fc4c..13fbeeeedc75ced4f2575905c29dc04af8956b5b 100644 (file)
@@ -416,12 +416,6 @@ func (p *printer) popPrefix() {
        }
 }
 
-var (
-       marshalerType     = reflect.TypeFor[Marshaler]()
-       marshalerAttrType = reflect.TypeFor[MarshalerAttr]()
-       textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]()
-)
-
 // marshalValue writes one or more XML elements representing val.
 // If val was obtained from a struct field, finfo must have its details.
 func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplate *StartElement) error {
@@ -450,24 +444,32 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
        typ := val.Type()
 
        // Check for marshaler.
-       if val.CanInterface() && typ.Implements(marshalerType) {
-               return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
+       if val.CanInterface() {
+               if marshaler, ok := reflect.TypeAssert[Marshaler](val); ok {
+                       return p.marshalInterface(marshaler, defaultStart(typ, finfo, startTemplate))
+               }
        }
        if val.CanAddr() {
                pv := val.Addr()
-               if pv.CanInterface() && pv.Type().Implements(marshalerType) {
-                       return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
+               if pv.CanInterface() {
+                       if marshaler, ok := reflect.TypeAssert[Marshaler](pv); ok {
+                               return p.marshalInterface(marshaler, defaultStart(pv.Type(), finfo, startTemplate))
+                       }
                }
        }
 
        // Check for text marshaler.
-       if val.CanInterface() && typ.Implements(textMarshalerType) {
-               return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
+       if val.CanInterface() {
+               if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](val); ok {
+                       return p.marshalTextInterface(textMarshaler, defaultStart(typ, finfo, startTemplate))
+               }
        }
        if val.CanAddr() {
                pv := val.Addr()
-               if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
-                       return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
+               if pv.CanInterface() {
+                       if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok {
+                               return p.marshalTextInterface(textMarshaler, defaultStart(pv.Type(), finfo, startTemplate))
+                       }
                }
        }
 
@@ -503,7 +505,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
                        start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
                } else {
                        fv := xmlname.value(val, dontInitNilPointers)
-                       if v, ok := fv.Interface().(Name); ok && v.Local != "" {
+                       if v, ok := reflect.TypeAssert[Name](fv); ok && v.Local != "" {
                                start.Name = v
                        }
                }
@@ -580,21 +582,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
 
 // marshalAttr marshals an attribute with the given name and value, adding to start.Attr.
 func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) error {
-       if val.CanInterface() && val.Type().Implements(marshalerAttrType) {
-               attr, err := val.Interface().(MarshalerAttr).MarshalXMLAttr(name)
-               if err != nil {
-                       return err
-               }
-               if attr.Name.Local != "" {
-                       start.Attr = append(start.Attr, attr)
-               }
-               return nil
-       }
-
-       if val.CanAddr() {
-               pv := val.Addr()
-               if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
-                       attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
+       if val.CanInterface() {
+               if marshaler, ok := reflect.TypeAssert[MarshalerAttr](val); ok {
+                       attr, err := marshaler.MarshalXMLAttr(name)
                        if err != nil {
                                return err
                        }
@@ -605,19 +595,25 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
                }
        }
 
-       if val.CanInterface() && val.Type().Implements(textMarshalerType) {
-               text, err := val.Interface().(encoding.TextMarshaler).MarshalText()
-               if err != nil {
-                       return err
+       if val.CanAddr() {
+               pv := val.Addr()
+               if pv.CanInterface() {
+                       if marshaler, ok := reflect.TypeAssert[MarshalerAttr](pv); ok {
+                               attr, err := marshaler.MarshalXMLAttr(name)
+                               if err != nil {
+                                       return err
+                               }
+                               if attr.Name.Local != "" {
+                                       start.Attr = append(start.Attr, attr)
+                               }
+                               return nil
+                       }
                }
-               start.Attr = append(start.Attr, Attr{name, string(text)})
-               return nil
        }
 
-       if val.CanAddr() {
-               pv := val.Addr()
-               if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
-                       text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
+       if val.CanInterface() {
+               if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](val); ok {
+                       text, err := textMarshaler.MarshalText()
                        if err != nil {
                                return err
                        }
@@ -626,6 +622,20 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
                }
        }
 
+       if val.CanAddr() {
+               pv := val.Addr()
+               if pv.CanInterface() {
+                       if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok {
+                               text, err := textMarshaler.MarshalText()
+                               if err != nil {
+                                       return err
+                               }
+                               start.Attr = append(start.Attr, Attr{name, string(text)})
+                               return nil
+                       }
+               }
+       }
+
        // Dereference or skip nil pointer, interface values.
        switch val.Kind() {
        case reflect.Pointer, reflect.Interface:
@@ -647,7 +657,8 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
        }
 
        if val.Type() == attrType {
-               start.Attr = append(start.Attr, val.Interface().(Attr))
+               attr, _ := reflect.TypeAssert[Attr](val)
+               start.Attr = append(start.Attr, attr)
                return nil
        }
 
@@ -854,20 +865,9 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
                        if err := s.trim(finfo.parents); err != nil {
                                return err
                        }
-                       if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
-                               data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
-                               if err != nil {
-                                       return err
-                               }
-                               if err := emit(p, data); err != nil {
-                                       return err
-                               }
-                               continue
-                       }
-                       if vf.CanAddr() {
-                               pv := vf.Addr()
-                               if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
-                                       data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
+                       if vf.CanInterface() {
+                               if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](vf); ok {
+                                       data, err := textMarshaler.MarshalText()
                                        if err != nil {
                                                return err
                                        }
@@ -877,6 +877,21 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
                                        continue
                                }
                        }
+                       if vf.CanAddr() {
+                               pv := vf.Addr()
+                               if pv.CanInterface() {
+                                       if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok {
+                                               data, err := textMarshaler.MarshalText()
+                                               if err != nil {
+                                                       return err
+                                               }
+                                               if err := emit(p, data); err != nil {
+                                                       return err
+                                               }
+                                               continue
+                                       }
+                               }
+                       }
 
                        var scratch [64]byte
                        vf = indirect(vf)
@@ -902,7 +917,7 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
                                        return err
                                }
                        case reflect.Slice:
-                               if elem, ok := vf.Interface().([]byte); ok {
+                               if elem, ok := reflect.TypeAssert[[]byte](vf); ok {
                                        if err := emit(p, elem); err != nil {
                                                return err
                                        }
index af25c20f0618dc815daf4ab8cb49fd98abac6945..d3cb74b2c4311ae11dd57ef99c1c181023d812af 100644 (file)
@@ -255,28 +255,36 @@ func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
                }
                val = val.Elem()
        }
-       if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) {
+       if val.CanInterface() {
                // This is an unmarshaler with a non-pointer receiver,
                // so it's likely to be incorrect, but we do what we're told.
-               return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
+               if unmarshaler, ok := reflect.TypeAssert[UnmarshalerAttr](val); ok {
+                       return unmarshaler.UnmarshalXMLAttr(attr)
+               }
        }
        if val.CanAddr() {
                pv := val.Addr()
-               if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) {
-                       return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
+               if pv.CanInterface() {
+                       if unmarshaler, ok := reflect.TypeAssert[UnmarshalerAttr](pv); ok {
+                               return unmarshaler.UnmarshalXMLAttr(attr)
+                       }
                }
        }
 
        // Not an UnmarshalerAttr; try encoding.TextUnmarshaler.
-       if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
+       if val.CanInterface() {
                // This is an unmarshaler with a non-pointer receiver,
                // so it's likely to be incorrect, but we do what we're told.
-               return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
+               if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](val); ok {
+                       return textUnmarshaler.UnmarshalText([]byte(attr.Value))
+               }
        }
        if val.CanAddr() {
                pv := val.Addr()
-               if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
-                       return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
+               if pv.CanInterface() {
+                       if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok {
+                               return textUnmarshaler.UnmarshalText([]byte(attr.Value))
+                       }
                }
        }
 
@@ -303,12 +311,7 @@ func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
        return copyValue(val, []byte(attr.Value))
 }
 
-var (
-       attrType            = reflect.TypeFor[Attr]()
-       unmarshalerType     = reflect.TypeFor[Unmarshaler]()
-       unmarshalerAttrType = reflect.TypeFor[UnmarshalerAttr]()
-       textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
-)
+var attrType = reflect.TypeFor[Attr]()
 
 const (
        maxUnmarshalDepth     = 10000
@@ -352,27 +355,35 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) e
                val = val.Elem()
        }
 
-       if val.CanInterface() && val.Type().Implements(unmarshalerType) {
+       if val.CanInterface() {
                // This is an unmarshaler with a non-pointer receiver,
                // so it's likely to be incorrect, but we do what we're told.
-               return d.unmarshalInterface(val.Interface().(Unmarshaler), start)
+               if unmarshaler, ok := reflect.TypeAssert[Unmarshaler](val); ok {
+                       return d.unmarshalInterface(unmarshaler, start)
+               }
        }
 
        if val.CanAddr() {
                pv := val.Addr()
-               if pv.CanInterface() && pv.Type().Implements(unmarshalerType) {
-                       return d.unmarshalInterface(pv.Interface().(Unmarshaler), start)
+               if pv.CanInterface() {
+                       if unmarshaler, ok := reflect.TypeAssert[Unmarshaler](pv); ok {
+                               return d.unmarshalInterface(unmarshaler, start)
+                       }
                }
        }
 
-       if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
-               return d.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler))
+       if val.CanInterface() {
+               if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](val); ok {
+                       return d.unmarshalTextInterface(textUnmarshaler)
+               }
        }
 
        if val.CanAddr() {
                pv := val.Addr()
-               if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
-                       return d.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler))
+               if pv.CanInterface() {
+                       if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok {
+                               return d.unmarshalTextInterface(textUnmarshaler)
+                       }
                }
        }
 
@@ -453,7 +464,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) e
                                return UnmarshalError(e)
                        }
                        fv := finfo.value(sv, initNilPointers)
-                       if _, ok := fv.Interface().(Name); ok {
+                       if _, ok := reflect.TypeAssert[Name](fv); ok {
                                fv.Set(reflect.ValueOf(start.Name))
                        }
                }
@@ -578,20 +589,24 @@ Loop:
                }
        }
 
-       if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
-               if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
-                       return err
+       if saveData.IsValid() && saveData.CanInterface() {
+               if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](saveData); ok {
+                       if err := textUnmarshaler.UnmarshalText(data); err != nil {
+                               return err
+                       }
+                       saveData = reflect.Value{}
                }
-               saveData = reflect.Value{}
        }
 
        if saveData.IsValid() && saveData.CanAddr() {
                pv := saveData.Addr()
-               if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
-                       if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
-                               return err
+               if pv.CanInterface() {
+                       if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok {
+                               if err := textUnmarshaler.UnmarshalText(data); err != nil {
+                                       return err
+                               }
+                               saveData = reflect.Value{}
                        }
-                       saveData = reflect.Value{}
                }
        }