]> Cypherpunks repositories - gostls13.git/commitdiff
reflect: add FieldByNameFunc
authorRaif S. Naffah <go@naffah-raif.name>
Sun, 18 Apr 2010 22:22:36 +0000 (15:22 -0700)
committerRuss Cox <rsc@golang.org>
Sun, 18 Apr 2010 22:22:36 +0000 (15:22 -0700)
xml: add support for XML marshalling embedded structs.

R=rsc
CC=golang-dev
https://golang.org/cl/837042

src/pkg/reflect/type.go
src/pkg/reflect/value.go
src/pkg/xml/embed_test.go [new file with mode: 0644]
src/pkg/xml/read.go

index a8df033af4983a6afe764d43db573d7938572022..eb1ba52a9f2bc06c597ff496390deb4de822a8d5 100644 (file)
@@ -507,7 +507,7 @@ func (t *StructType) FieldByIndex(index []int) (f StructField) {
 
 const inf = 1 << 30 // infinity - no struct has that many nesting levels
 
-func (t *StructType) fieldByName(name string, mark map[*StructType]bool, depth int) (ff StructField, fd int) {
+func (t *StructType) fieldByNameFunc(match func(string) bool, mark map[*StructType]bool, depth int) (ff StructField, fd int) {
        fd = inf // field depth
 
        if mark[t] {
@@ -522,7 +522,7 @@ L: for i, _ := range t.fields {
                f := t.Field(i)
                d := inf
                switch {
-               case f.Name == name:
+               case match(f.Name):
                        // Matching top-level field.
                        d = depth
                case f.Anonymous:
@@ -531,13 +531,13 @@ L: for i, _ := range t.fields {
                                ft = pt.Elem()
                        }
                        switch {
-                       case ft.Name() == name:
+                       case match(ft.Name()):
                                // Matching anonymous top-level field.
                                d = depth
                        case fd > depth:
                                // No top-level field yet; look inside nested structs.
                                if st, ok := ft.(*StructType); ok {
-                                       f, d = st.fieldByName(name, mark, depth+1)
+                                       f, d = st.fieldByNameFunc(match, mark, depth+1)
                                }
                        }
                }
@@ -576,7 +576,13 @@ L: for i, _ := range t.fields {
 // FieldByName returns the struct field with the given name
 // and a boolean to indicate if the field was found.
 func (t *StructType) FieldByName(name string) (f StructField, present bool) {
-       if ff, fd := t.fieldByName(name, make(map[*StructType]bool), 0); fd < inf {
+       return t.FieldByNameFunc(func(s string) bool { return s == name })
+}
+
+// FieldByNameFunc returns the struct field with a name that satisfies the
+// match function and a boolean to indicate if the field was found.
+func (t *StructType) FieldByNameFunc(match func(string) bool) (f StructField, present bool) {
+       if ff, fd := t.fieldByNameFunc(match, make(map[*StructType]bool), 0); fd < inf {
                ff.Index = ff.Index[0 : fd+1]
                f, present = ff, true
        }
index f21c564d537c3b76e5743022cc66e40e5455051c..d8ddb289a4075351b7af660616bf8f54169fdc58 100644 (file)
@@ -1251,6 +1251,16 @@ func (t *StructValue) FieldByName(name string) Value {
        return nil
 }
 
+// FieldByNameFunc returns the struct field with a name that satisfies the
+// match function.
+// The result is nil if no field was found.
+func (t *StructValue) FieldByNameFunc(match func(string) bool) Value {
+       if f, ok := t.Type().(*StructType).FieldByNameFunc(match); ok {
+               return t.FieldByIndex(f.Index)
+       }
+       return nil
+}
+
 // NumField returns the number of fields in the struct.
 func (v *StructValue) NumField() int { return v.typ.(*StructType).NumField() }
 
diff --git a/src/pkg/xml/embed_test.go b/src/pkg/xml/embed_test.go
new file mode 100644 (file)
index 0000000..abfe781
--- /dev/null
@@ -0,0 +1,124 @@
+// Copyright 2010 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xml
+
+import "testing"
+
+type C struct {
+       Name string
+       Open bool
+}
+
+type A struct {
+       XMLName Name "http://domain a"
+       C
+       B      B
+       FieldA string
+}
+
+type B struct {
+       XMLName Name "b"
+       C
+       FieldB string
+}
+
+const _1a = `
+<?xml version="1.0" encoding="UTF-8"?>
+<a xmlns="http://domain">
+  <name>KmlFile</name>
+  <open>1</open>
+  <b>
+    <name>Absolute</name>
+    <open>0</open>
+    <fieldb>bar</fieldb>
+  </b>
+  <fielda>foo</fielda>
+</a>
+`
+
+// Tests that embedded structs are marshalled.
+func TestEmbedded1(t *testing.T) {
+       var a A
+       if e := Unmarshal(StringReader(_1a), &a); e != nil {
+               t.Fatalf("Unmarshal: %s", e)
+       }
+       if a.FieldA != "foo" {
+               t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.FieldA)
+       }
+       if a.Name != "KmlFile" {
+               t.Fatalf("Unmarshal: expected 'KmlFile' but found '%s'", a.Name)
+       }
+       if !a.Open {
+               t.Fatal("Unmarshal: expected 'true' but found otherwise")
+       }
+       if a.B.FieldB != "bar" {
+               t.Fatalf("Unmarshal: expected 'bar' but found '%s'", a.B.FieldB)
+       }
+       if a.B.Name != "Absolute" {
+               t.Fatalf("Unmarshal: expected 'Absolute' but found '%s'", a.B.Name)
+       }
+       if a.B.Open {
+               t.Fatal("Unmarshal: expected 'false' but found otherwise")
+       }
+}
+
+type A2 struct {
+       XMLName Name "http://domain a"
+       XY      string
+       Xy      string
+}
+
+const _2a = `
+<?xml version="1.0" encoding="UTF-8"?>
+<a xmlns="http://domain">
+  <xy>foo</xy>
+</a>
+`
+
+// Tests that conflicting field names get excluded.
+func TestEmbedded2(t *testing.T) {
+       var a A2
+       if e := Unmarshal(StringReader(_2a), &a); e != nil {
+               t.Fatalf("Unmarshal: %s", e)
+       }
+       if a.XY != "" {
+               t.Fatalf("Unmarshal: expected empty string but found '%s'", a.XY)
+       }
+       if a.Xy != "" {
+               t.Fatalf("Unmarshal: expected empty string but found '%s'", a.Xy)
+       }
+}
+
+type A3 struct {
+       XMLName Name "http://domain a"
+       xy      string
+}
+
+// Tests that private fields are not set.
+func TestEmbedded3(t *testing.T) {
+       var a A3
+       if e := Unmarshal(StringReader(_2a), &a); e != nil {
+               t.Fatalf("Unmarshal: %s", e)
+       }
+       if a.xy != "" {
+               t.Fatalf("Unmarshal: expected empty string but found '%s'", a.xy)
+       }
+}
+
+type A4 struct {
+       XMLName Name "http://domain a"
+       Any     string
+}
+
+// Tests that private fields are not set.
+func TestEmbedded4(t *testing.T) {
+       var a A4
+       if e := Unmarshal(StringReader(_2a), &a); e != nil {
+               t.Fatalf("Unmarshal: %s", e)
+       }
+       if a.Any != "foo" {
+               t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.Any)
+       }
+}
index 9eb0be2538ba0f635405dc7d5456a7af8a1e6925..45db7daa36500328b1db2cccaeb6e6bd5cd99f0a 100644 (file)
@@ -12,6 +12,7 @@ import (
        "strconv"
        "strings"
        "unicode"
+       "utf8"
 )
 
 // BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
@@ -331,24 +332,24 @@ Loop:
                case StartElement:
                        // Sub-element.
                        // Look up by tag name.
-                       // If that fails, fall back to mop-up field named "Any".
                        if sv != nil {
                                k := fieldName(t.Name.Local)
-                               any := -1
-                               for i, n := 0, styp.NumField(); i < n; i++ {
-                                       f := styp.Field(i)
-                                       if strings.ToLower(f.Name) == k {
-                                               if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil {
-                                                       return err
-                                               }
-                                               continue Loop
-                                       }
-                                       if any < 0 && f.Name == "Any" {
-                                               any = i
+                               match := func(s string) bool {
+                                       // check if the name matches ignoring case
+                                       if strings.ToLower(s) != strings.ToLower(k) {
+                                               return false
                                        }
+                                       // now check that it's public
+                                       c, _ := utf8.DecodeRuneInString(s)
+                                       return unicode.IsUpper(c)
+                               }
+
+                               f, found := styp.FieldByNameFunc(match)
+                               if !found { // fall back to mop-up field named "Any"
+                                       f, found = styp.FieldByName("Any")
                                }
-                               if any >= 0 {
-                                       if err := p.unmarshal(sv.FieldByIndex(styp.Field(any).Index), &t); err != nil {
+                               if found {
+                                       if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil {
                                                return err
                                        }
                                        continue Loop