]> Cypherpunks repositories - gostls13.git/commitdiff
allow any scalar type in xml.Unmarshal.
authorRob Pike <r@golang.org>
Tue, 2 Feb 2010 00:53:10 +0000 (11:53 +1100)
committerRob Pike <r@golang.org>
Tue, 2 Feb 2010 00:53:10 +0000 (11:53 +1100)
Fixes #574.

R=rsc
CC=golang-dev
https://golang.org/cl/196056

src/pkg/xml/read.go
src/pkg/xml/xml_test.go

index 4f944038e878e7bd5629fef6e929f98c73734463..c85b697025a66eb2fe6edf904c02e8f6e13013d4 100644 (file)
@@ -9,6 +9,7 @@ import (
        "io"
        "os"
        "reflect"
+       "strconv"
        "strings"
        "unicode"
 )
@@ -106,6 +107,10 @@ import (
 //
 // Unmarshal maps an XML element to a bool by setting the bool to true.
 //
+// Unmarshal maps an XML element to an integer or floating-point
+// field by setting the field to the result of interpreting the string
+// value in decimal.  There is no check for overflow.
+//
 // Unmarshal maps an XML element to an xml.Name by recording the
 // element name.
 //
@@ -200,6 +205,9 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error {
                styp        *reflect.StructType
        )
        switch v := val.(type) {
+       default:
+               return os.ErrorString("unknown type " + v.Type().String())
+
        case *reflect.BoolValue:
                v.Set(true)
 
@@ -232,7 +240,11 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error {
                }
                return nil
 
-       case *reflect.StringValue:
+       case *reflect.StringValue,
+               *reflect.IntValue, *reflect.UintValue, *reflect.UintptrValue,
+               *reflect.Int8Value, *reflect.Int16Value, *reflect.Int32Value, *reflect.Int64Value,
+               *reflect.Uint8Value, *reflect.Uint16Value, *reflect.Uint32Value, *reflect.Uint64Value,
+               *reflect.FloatValue, *reflect.Float32Value, *reflect.Float64Value:
                saveData = v
 
        case *reflect.StructValue:
@@ -365,8 +377,103 @@ Loop:
                }
        }
 
-       // Save accumulated character data and comments
+       var err os.Error
+       // Helper functions for integer and unsigned integer conversions
+       var itmp int64
+       getInt64 := func() bool {
+               itmp, err = strconv.Atoi64(string(data))
+               // TODO: should check sizes
+               return err == nil
+       }
+       var utmp uint64
+       getUint64 := func() bool {
+               utmp, err = strconv.Atoui64(string(data))
+               // TODO: check for overflow?
+               return err == nil
+       }
+       var ftmp float64
+       getFloat64 := func() bool {
+               ftmp, err = strconv.Atof64(string(data))
+               // TODO: check for overflow?
+               return err == nil
+       }
+
+       // Save accumulated data and comments
        switch t := saveData.(type) {
+       case nil:
+               // Probably a comment, handled below
+       default:
+               return os.ErrorString("cannot happen: unknown type " + t.Type().String())
+       case *reflect.IntValue:
+               if !getInt64() {
+                       return err
+               }
+               t.Set(int(itmp))
+       case *reflect.Int8Value:
+               if !getInt64() {
+                       return err
+               }
+               t.Set(int8(itmp))
+       case *reflect.Int16Value:
+               if !getInt64() {
+                       return err
+               }
+               t.Set(int16(itmp))
+       case *reflect.Int32Value:
+               if !getInt64() {
+                       return err
+               }
+               t.Set(int32(itmp))
+       case *reflect.Int64Value:
+               if !getInt64() {
+                       return err
+               }
+               t.Set(itmp)
+       case *reflect.UintValue:
+               if !getUint64() {
+                       return err
+               }
+               t.Set(uint(utmp))
+       case *reflect.Uint8Value:
+               if !getUint64() {
+                       return err
+               }
+               t.Set(uint8(utmp))
+       case *reflect.Uint16Value:
+               if !getUint64() {
+                       return err
+               }
+               t.Set(uint16(utmp))
+       case *reflect.Uint32Value:
+               if !getUint64() {
+                       return err
+               }
+               t.Set(uint32(utmp))
+       case *reflect.Uint64Value:
+               if !getUint64() {
+                       return err
+               }
+               t.Set(utmp)
+       case *reflect.UintptrValue:
+               if !getUint64() {
+                       return err
+               }
+               t.Set(uintptr(utmp))
+       case *reflect.FloatValue:
+               if !getFloat64() {
+                       return err
+               }
+               t.Set(float(ftmp))
+       case *reflect.Float32Value:
+               if !getFloat64() {
+                       return err
+               }
+               t.Set(float32(ftmp))
+       case *reflect.Float64Value:
+               if !getFloat64() {
+                       return err
+               }
+               t.Set(ftmp)
        case *reflect.StringValue:
                t.Set(string(data))
        case *reflect.SliceValue:
index f228dfba3709e549685672bcdf22548185a48bc6..fa19495001ea11d646e06e6fa72db77cd60fab4c 100644 (file)
@@ -214,6 +214,76 @@ func TestSyntax(t *testing.T) {
        }
 }
 
+type allScalars struct {
+       Bool    bool
+       Int     int
+       Int8    int8
+       Int16   int16
+       Int32   int32
+       Int64   int64
+       Uint    int
+       Uint8   uint8
+       Uint16  uint16
+       Uint32  uint32
+       Uint64  uint64
+       Uintptr uintptr
+       Float   float
+       Float32 float32
+       Float64 float64
+       String  string
+}
+
+var all = allScalars{
+       Bool: true,
+       Int: 1,
+       Int8: -2,
+       Int16: 3,
+       Int32: -4,
+       Int64: 5,
+       Uint: 6,
+       Uint8: 7,
+       Uint16: 8,
+       Uint32: 9,
+       Uint64: 10,
+       Uintptr: 11,
+       Float: 12.0,
+       Float32: 13.0,
+       Float64: 14.0,
+       String: "15",
+}
+
+const testScalarsInput = `<allscalars>
+       <bool/>
+       <int>1</int>
+       <int8>-2</int8>
+       <int16>3</int16>
+       <int32>-4</int32>
+       <int64>5</int64>
+       <uint>6</uint>
+       <uint8>7</uint8>
+       <uint16>8</uint16>
+       <uint32>9</uint32>
+       <uint64>10</uint64>
+       <uintptr>11</uintptr>
+       <float>12.0</float>
+       <float32>13.0</float32>
+       <float64>14.0</float64>
+       <string>15</string>
+</allscalars>`
+
+func TestAllScalars(t *testing.T) {
+       var a allScalars
+       buf := bytes.NewBufferString(testScalarsInput)
+       err := Unmarshal(buf, &a)
+
+       if err != nil {
+               t.Fatal(err)
+       }
+       if !reflect.DeepEqual(a, all) {
+               t.Errorf("expected %+v got %+v", a, all)
+       }
+}
+
 type item struct {
        Field_a string
 }