]> Cypherpunks repositories - gostls13.git/commitdiff
fmt.Scan: refactor the implementation so format-driven and normal scanning use the...
authorRob Pike <r@golang.org>
Mon, 31 May 2010 21:53:15 +0000 (14:53 -0700)
committerRob Pike <r@golang.org>
Mon, 31 May 2010 21:53:15 +0000 (14:53 -0700)
simplifies the code significantly.
Still TODO:
- proper format handling
- strings

R=rsc
CC=golang-dev
https://golang.org/cl/1432041

src/pkg/fmt/scan.go
src/pkg/fmt/scan_test.go

index ec7ba9bf591b43d15085eda2b14ae61561cf27d8..9851d4d2951b92906fa08b4db76ac6b028e64820 100644 (file)
@@ -251,9 +251,23 @@ var intBits = uint(reflect.Typeof(int(0)).Size() * 8)
 var uintptrBits = uint(reflect.Typeof(int(0)).Size() * 8)
 var complexError = os.ErrorString("syntax error scanning complex number")
 
+// okVerb verifies that the verb is present in the list, setting s.err appropriately if not.
+func (s *ss) okVerb(verb int, okVerbs, typ string) bool {
+       if s.err != nil { // don't overwrite error
+               return false
+       }
+       for _, v := range okVerbs {
+               if v == verb {
+                       return true
+               }
+       }
+       s.err = os.ErrorString("bad verb %" + string(verb) + " for " + typ)
+       return false
+}
+
 // scanBool converts the token to a boolean value.
-func (s *ss) scanBool(tok string) bool {
-       if s.err != nil {
+func (s *ss) scanBool(verb int, tok string) bool {
+       if !s.okVerb(verb, "tv", "boolean") {
                return false
        }
        var b bool
@@ -261,9 +275,27 @@ func (s *ss) scanBool(tok string) bool {
        return b
 }
 
+func (s *ss) getBase(verb int) int {
+       s.okVerb(verb, "bdoxXv", "integer") // sets s.err
+       base := 10
+       switch verb {
+       case 'b':
+               base = 2
+       case 'o':
+               base = 8
+       case 'x', 'X':
+               base = 16
+       }
+       return base
+}
+
 // convertInt returns the value of the integer
 // stored in the token, checking for overflow.  Any error is stored in s.err.
-func (s *ss) convertInt(tok string, bitSize uint, base int) (i int64) {
+func (s *ss) convertInt(verb int, tok string, bitSize uint) (i int64) {
+       base := s.getBase(verb)
+       if s.err != nil {
+               return 0
+       }
        i, s.err = strconv.Btoi64(tok, base)
        x := (i << (64 - bitSize)) >> (64 - bitSize)
        if x != i {
@@ -274,7 +306,11 @@ func (s *ss) convertInt(tok string, bitSize uint, base int) (i int64) {
 
 // convertUint returns the value of the unsigned integer
 // stored in the token, checking for overflow.  Any error is stored in s.err.
-func (s *ss) convertUint(tok string, bitSize uint, base int) (i uint64) {
+func (s *ss) convertUint(verb int, tok string, bitSize uint) (i uint64) {
+       base := s.getBase(verb)
+       if s.err != nil {
+               return 0
+       }
        i, s.err = strconv.Btoui64(tok, base)
        x := (i << (64 - bitSize)) >> (64 - bitSize)
        if x != i {
@@ -283,79 +319,6 @@ func (s *ss) convertUint(tok string, bitSize uint, base int) (i uint64) {
        return i
 }
 
-// scanInteger converts the token to an integer in the appropriate base
-// and stores the result according to the type of the field.
-func (s *ss) scanInteger(tok string, field interface{}, base int) {
-       switch v := field.(type) {
-       case *int:
-               *v = int(s.convertInt(tok, intBits, base))
-               return
-       case *int8:
-               *v = int8(s.convertInt(tok, 8, base))
-               return
-       case *int16:
-               *v = int16(s.convertInt(tok, 16, base))
-               return
-       case *int32:
-               *v = int32(s.convertInt(tok, 32, base))
-               return
-       case *int64:
-               *v = s.convertInt(tok, 64, base)
-               return
-       case *uint:
-               *v = uint(s.convertUint(tok, intBits, base))
-               return
-       case *uint8:
-               *v = uint8(s.convertUint(tok, 8, base))
-               return
-       case *uint16:
-               *v = uint16(s.convertUint(tok, 16, base))
-               return
-       case *uint32:
-               *v = uint32(s.convertUint(tok, 32, base))
-               return
-       case *uint64:
-               *v = uint64(s.convertUint(tok, 64, base))
-               return
-       case *uintptr:
-               *v = uintptr(s.convertUint(tok, uintptrBits, base))
-               return
-       }
-       // Not a basic type; probably a renamed type. We need to use reflection.
-       v := reflect.NewValue(field)
-       ptr, ok := v.(*reflect.PtrValue)
-       if !ok {
-               s.typeError(field, "integer")
-               return
-       }
-       switch v := ptr.Elem().(type) {
-       case *reflect.IntValue:
-               v.Set(int(s.convertInt(tok, intBits, base)))
-       case *reflect.Int8Value:
-               v.Set(int8(s.convertInt(tok, 8, base)))
-       case *reflect.Int16Value:
-               v.Set(int16(s.convertInt(tok, 16, base)))
-       case *reflect.Int32Value:
-               v.Set(int32(s.convertInt(tok, 32, base)))
-       case *reflect.Int64Value:
-               v.Set(s.convertInt(tok, 64, base))
-       case *reflect.UintValue:
-               v.Set(uint(s.convertUint(tok, intBits, base)))
-       case *reflect.Uint8Value:
-               v.Set(uint8(s.convertUint(tok, 8, base)))
-       case *reflect.Uint16Value:
-               v.Set(uint16(s.convertUint(tok, 16, base)))
-       case *reflect.Uint32Value:
-               v.Set(uint32(s.convertUint(tok, 32, base)))
-       case *reflect.Uint64Value:
-               v.Set(s.convertUint(tok, 64, base))
-       case *reflect.UintptrValue:
-               v.Set(uintptr(s.convertUint(tok, uintptrBits, base)))
-       default:
-               s.err = os.ErrorString("internal error: unknown int type")
-       }
-}
-
 // complexParts returns the strings representing the real and imaginary parts of the string.
 func (s *ss) complexParts(str string) (real, imag string) {
        if len(str) > 2 && str[0] == '(' && str[len(str)-1] == ')' {
@@ -436,9 +399,6 @@ func (s *ss) scanFloat64(str string) float64 {
 // If we're reading complex64, atof will parse float32s and convert them
 // to float64's to avoid reproducing this code for each complex type.
 func (s *ss) scanComplex(tok string, atof func(*ss, string) float64) complex128 {
-       if s.err != nil {
-               return 0
-       }
        sreal, simag := s.complexParts(tok)
        if s.err != nil {
                return 0
@@ -455,15 +415,22 @@ func (s *ss) scanComplex(tok string, atof func(*ss, string) float64) complex128
        return cmplx(real, imag)
 }
 
+const floatVerbs = "eEfFgGv"
+
 // scanOne scans a single value, deriving the scanner from the type of the argument.
-func (s *ss) scanOne(field interface{}) {
+func (s *ss) scanOne(verb int, field interface{}) {
+       // If the parameter has its own Scan method, use that.
+       if v, ok := field.(Scanner); ok {
+               s.err = v.Scan(s)
+               return
+       }
        tok := s.token()
        if s.err != nil {
                return
        }
        switch v := field.(type) {
        case *bool:
-               *v = s.scanBool(tok)
+               *v = s.scanBool(verb, tok)
        case *complex:
                *v = complex(s.scanComplex(tok, (*ss).scanFloat))
        case *complex64:
@@ -471,62 +438,73 @@ func (s *ss) scanOne(field interface{}) {
        case *complex128:
                *v = s.scanComplex(tok, (*ss).scanFloat64)
        case *int:
-               *v = int(s.convertInt(tok, intBits, 10))
+               *v = int(s.convertInt(verb, tok, intBits))
        case *int8:
-               *v = int8(s.convertInt(tok, 8, 10))
+               *v = int8(s.convertInt(verb, tok, 8))
        case *int16:
-               *v = int16(s.convertInt(tok, 16, 10))
+               *v = int16(s.convertInt(verb, tok, 16))
        case *int32:
-               *v = int32(s.convertInt(tok, 32, 10))
+               *v = int32(s.convertInt(verb, tok, 32))
        case *int64:
-               *v = s.convertInt(tok, intBits, 10)
+               *v = s.convertInt(verb, tok, intBits)
        case *uint:
-               *v = uint(s.convertUint(tok, intBits, 10))
+               *v = uint(s.convertUint(verb, tok, intBits))
        case *uint8:
-               *v = uint8(s.convertUint(tok, 8, 10))
+               *v = uint8(s.convertUint(verb, tok, 8))
        case *uint16:
-               *v = uint16(s.convertUint(tok, 16, 10))
+               *v = uint16(s.convertUint(verb, tok, 16))
        case *uint32:
-               *v = uint32(s.convertUint(tok, 32, 10))
+               *v = uint32(s.convertUint(verb, tok, 32))
        case *uint64:
-               *v = s.convertUint(tok, 64, 10)
+               *v = s.convertUint(verb, tok, 64)
        case *uintptr:
-               *v = uintptr(s.convertUint(tok, uintptrBits, 10))
+               *v = uintptr(s.convertUint(verb, tok, uintptrBits))
        case *float:
-               if s.err == nil {
+               if s.okVerb(verb, floatVerbs, "float") {
                        *v, s.err = strconv.Atof(tok)
-               } else {
-                       *v = 0
                }
        case *float32:
-               if s.err == nil {
+               if s.okVerb(verb, floatVerbs, "float32") {
                        *v, s.err = strconv.Atof32(tok)
-               } else {
-                       *v = 0
                }
        case *float64:
-               if s.err == nil {
+               if s.okVerb(verb, floatVerbs, "float64") {
                        *v, s.err = strconv.Atof64(tok)
-               } else {
-                       *v = 0
                }
        case *string:
                *v = tok
        default:
-               t := reflect.Typeof(v)
-               str := t.String()
-               ptr, ok := t.(*reflect.PtrType)
+               val := reflect.NewValue(v)
+               ptr, ok := val.(*reflect.PtrValue)
                if !ok {
-                       s.err = os.ErrorString("Scan: type not a pointer: " + str)
+                       s.err = os.ErrorString("Scan: type not a pointer: " + val.Type().String())
                        return
                }
-               switch ptr.Elem().(type) {
-               case *reflect.IntType, *reflect.Int8Type, *reflect.Int16Type, *reflect.Int32Type, *reflect.Int64Type:
-                       s.scanInteger(tok, v, 10)
-               case *reflect.UintType, *reflect.Uint8Type, *reflect.Uint16Type, *reflect.Uint32Type, *reflect.Uint64Type, *reflect.UintptrType:
-                       s.scanInteger(tok, v, 10)
+               switch v := ptr.Elem().(type) {
+               case *reflect.IntValue:
+                       v.Set(int(s.convertInt(verb, tok, intBits)))
+               case *reflect.Int8Value:
+                       v.Set(int8(s.convertInt(verb, tok, 8)))
+               case *reflect.Int16Value:
+                       v.Set(int16(s.convertInt(verb, tok, 16)))
+               case *reflect.Int32Value:
+                       v.Set(int32(s.convertInt(verb, tok, 32)))
+               case *reflect.Int64Value:
+                       v.Set(s.convertInt(verb, tok, 64))
+               case *reflect.UintValue:
+                       v.Set(uint(s.convertUint(verb, tok, intBits)))
+               case *reflect.Uint8Value:
+                       v.Set(uint8(s.convertUint(verb, tok, 8)))
+               case *reflect.Uint16Value:
+                       v.Set(uint16(s.convertUint(verb, tok, 16)))
+               case *reflect.Uint32Value:
+                       v.Set(uint32(s.convertUint(verb, tok, 32)))
+               case *reflect.Uint64Value:
+                       v.Set(s.convertUint(verb, tok, 64))
+               case *reflect.UintptrValue:
+                       v.Set(uintptr(s.convertUint(verb, tok, uintptrBits)))
                default:
-                       s.err = os.ErrorString("Scan: can't handle type: " + t.String())
+                       s.err = os.ErrorString("Scan: can't handle type: " + val.Type().String())
                }
        }
 }
@@ -535,15 +513,7 @@ func (s *ss) scanOne(field interface{}) {
 // At the moment, it handles only pointers to basic types.
 func (s *ss) doScan(a []interface{}) int {
        for fieldnum, field := range a {
-               // If the parameter has its own Scan method, use that.
-               if v, ok := field.(Scanner); ok {
-                       s.err = v.Scan(s)
-                       if s.err != nil {
-                               return fieldnum
-                       }
-                       continue
-               }
-               s.scanOne(field)
+               s.scanOne('v', field)
                if s.err != nil {
                        return fieldnum
                }
@@ -592,48 +562,18 @@ func (s *ss) doScanf(format string, a []interface{}) int {
                        // TODO: WHAT NOW?
                        continue
                }
+
                if fieldnum >= len(a) { // out of operands
                        s.err = os.ErrorString("too few operands for format %" + format[i-w:])
-                       return fieldnum
+                       break
                }
                field := a[fieldnum]
-               fieldnum++
 
-               // If the parameter has its own Scan method, use that.
-               if v, ok := field.(Scanner); ok {
-                       s.err = v.Scan(s)
-                       if s.err != nil {
-                               return fieldnum - 1
-                       }
-                       continue
-               }
-               if c == 'v' {
-                       // Default format works; just call doScan, but note that it will scan for the token
-                       s.scanOne(field)
-               } else {
-                       tok := s.token()
-                       switch c {
-                       case 't':
-                               if v, ok := field.(*bool); ok {
-                                       *v = s.scanBool(tok)
-                               } else {
-                                       s.typeError(field, "boolean")
-                               }
-                       case 'b':
-                               s.scanInteger(tok, field, 2)
-                       case 'o':
-                               s.scanInteger(tok, field, 8)
-                       case 'd':
-                               s.scanInteger(tok, field, 10)
-                       case 'x', 'X':
-                               s.scanInteger(tok, field, 16)
-                       default:
-                               s.err = os.ErrorString("unknown scanning verb %" + format[i-w:])
-                       }
-                       if s.err != nil {
-                               return fieldnum - 1
-                       }
+               s.scanOne(c, field)
+               if s.err != nil {
+                       break
                }
+               fieldnum++
        }
        return fieldnum
 }
index ca51cf0a219d6c311f00c026b0889abfd35a440f..55808e964dcba42c0457554b77a412cb83357291 100644 (file)
@@ -170,15 +170,15 @@ var scanfTests = []ScanfTest{
 
        // Renamed types
        ScanfTest{"%v", "101\n", &renamedIntVal, renamedInt(101)},
-       ScanfTest{"%d", "102\n", &renamedIntVal, renamedInt(102)},
+       ScanfTest{"%o", "0146\n", &renamedIntVal, renamedInt(102)},
        ScanfTest{"%v", "103\n", &renamedUintVal, renamedUint(103)},
        ScanfTest{"%d", "104\n", &renamedUintVal, renamedUint(104)},
        ScanfTest{"%d", "105\n", &renamedInt8Val, renamedInt8(105)},
        ScanfTest{"%d", "106\n", &renamedInt16Val, renamedInt16(106)},
        ScanfTest{"%d", "107\n", &renamedInt32Val, renamedInt32(107)},
        ScanfTest{"%d", "108\n", &renamedInt64Val, renamedInt64(108)},
-       ScanfTest{"%d", "109\n", &renamedUint8Val, renamedUint8(109)},
-       ScanfTest{"%d", "110\n", &renamedUint16Val, renamedUint16(110)},
+       ScanfTest{"%x", "6D\n", &renamedUint8Val, renamedUint8(109)},
+       ScanfTest{"%o", "0156\n", &renamedUint16Val, renamedUint16(110)},
        ScanfTest{"%d", "111\n", &renamedUint32Val, renamedUint32(111)},
        ScanfTest{"%d", "112\n", &renamedUint64Val, renamedUint64(112)},
        ScanfTest{"%d", "113\n", &renamedUintptrVal, renamedUintptr(113)},