]> Cypherpunks repositories - gostls13.git/commitdiff
fmt.Scan, fmt.Scanln: Start of a simple scanning API in the fmt package.
authorRob Pike <r@golang.org>
Wed, 26 May 2010 04:02:35 +0000 (21:02 -0700)
committerRob Pike <r@golang.org>
Wed, 26 May 2010 04:02:35 +0000 (21:02 -0700)
Still to do:
- composite types
- user-defined scanners
- format-driven scanning
The package comment will be updated when more of the functionality is in place.

R=rsc
CC=golang-dev
https://golang.org/cl/1252045

src/pkg/fmt/Makefile
src/pkg/fmt/scan.go [new file with mode: 0644]
src/pkg/fmt/scan_test.go [new file with mode: 0644]

index 757af41bb44160850b7ac485eeacc22c6dca5694..28ea396c7526be2aaa9310ffb53a3cc0ade15e0e 100644 (file)
@@ -8,5 +8,6 @@ TARG=fmt
 GOFILES=\
        format.go\
        print.go\
+       scan.go\
 
 include ../../Make.pkg
diff --git a/src/pkg/fmt/scan.go b/src/pkg/fmt/scan.go
new file mode 100644 (file)
index 0000000..42469b9
--- /dev/null
@@ -0,0 +1,413 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt
+
+import (
+       "bytes"
+       "io"
+       "os"
+       "reflect"
+       "strconv"
+       "unicode"
+       "utf8"
+)
+
+// readRuner is the interface to something that can read runes.  If
+// the object provided to Scan does not satisfy this interface, the
+// object will be wrapped by a readRune object.
+type readRuner interface {
+       ReadRune() (rune int, size int, err os.Error)
+}
+
+type ss struct {
+       rr        readRuner    // where to read input
+       buf       bytes.Buffer // token accumulator
+       nlIsSpace bool         // whether newline counts as white space
+       peekRune  int          // one-rune lookahead
+       err       os.Error
+}
+
+// readRune is a structure to enable reading UTF-8 encoded code points
+// from an io.Reader.  It is used if the Reader given to the scanner does
+// not already implement readRuner.
+// TODO: readByteRune for things that can read bytes.
+type readRune struct {
+       reader io.Reader
+       buf    [utf8.UTFMax]byte
+}
+
+// ReadRune returns the next UTF-8 encoded code point from the
+// io.Reader inside r.
+func (r readRune) ReadRune() (rune int, size int, err os.Error) {
+       _, err = r.reader.Read(r.buf[0:1])
+       if err != nil {
+               return 0, 0, err
+       }
+       if r.buf[0] < utf8.RuneSelf { // fast check for common ASCII case
+               rune = int(r.buf[0])
+               return
+       }
+       for size := 1; size < utf8.UTFMax; size++ {
+               _, err = r.reader.Read(r.buf[size : size+1])
+               if err != nil {
+                       break
+               }
+               if !utf8.FullRune(&r.buf) {
+                       continue
+               }
+               if c, w := utf8.DecodeRune(r.buf[0:size]); w == size {
+                       rune = c
+                       return
+               }
+       }
+       return utf8.RuneError, 1, err
+}
+
+
+// A leaky bucket of reusable ss structures.
+var ssFree = make(chan *ss, 100)
+
+// Allocate a new ss struct.  Probably can grab the previous one from ssFree.
+func newScanState(r io.Reader, nlIsSpace bool) *ss {
+       s, ok := <-ssFree
+       if !ok {
+               s = new(ss)
+       }
+       if rr, ok := r.(readRuner); ok {
+               s.rr = rr
+       } else {
+               s.rr = readRune{reader: r}
+       }
+       s.nlIsSpace = nlIsSpace
+       s.peekRune = -1
+       s.err = nil
+       return s
+}
+
+// Save used ss structs in ssFree; avoid an allocation per invocation.
+func (s *ss) free() {
+       // Don't hold on to ss structs with large buffers.
+       if cap(s.buf.Bytes()) > 1024 {
+               return
+       }
+       s.buf.Reset()
+       s.rr = nil
+       _ = ssFree <- s
+}
+
+// readRune reads the next rune, but checks the peeked item first.
+func (s *ss) readRune() (rune int, err os.Error) {
+       if s.peekRune >= 0 {
+               rune = s.peekRune
+               s.peekRune = -1
+               return
+       }
+       rune, _, err = s.rr.ReadRune()
+       return
+}
+
+// token returns the next space-delimited string from the input.
+// For Scanln, it stops at newlines.  For Scan, newlines are treated as
+// spaces.
+func (s *ss) token() string {
+       s.buf.Reset()
+       // skip white space and maybe newline
+       for {
+               rune, err := s.readRune()
+               if err != nil {
+                       s.err = err
+                       return ""
+               }
+               if rune == '\n' {
+                       if s.nlIsSpace {
+                               continue
+                       }
+                       s.err = os.ErrorString("unexpected newline")
+                       return ""
+               }
+               if !unicode.IsSpace(rune) {
+                       s.buf.WriteRune(rune)
+                       break
+               }
+       }
+       // read until white space or newline
+       for {
+               rune, err := s.readRune()
+               if err != nil {
+                       if err == os.EOF {
+                               break
+                       }
+                       s.err = err
+                       return ""
+               }
+               if unicode.IsSpace(rune) {
+                       s.peekRune = rune
+                       break
+               }
+               s.buf.WriteRune(rune)
+       }
+       return s.buf.String()
+}
+
+// Scan parses text read from r, storing successive space-separated
+// values into successive arguments.  Newlines count as space.  Each
+// argument must be a pointer to a basic type.  It returns the number of
+// items successfully parsed.  If that is less than the number of arguments,
+// err will report why.
+func Scan(r io.Reader, a ...interface{}) (n int, err os.Error) {
+       s := newScanState(r, true)
+       n = s.doScan(a)
+       err = s.err
+       s.free()
+       return
+}
+
+// Scanln parses text read from r, storing successive space-separated
+// values into successive arguments.  Scanning stops at a newline and after
+// the final item there must be a newline or EOF.  Each argument must be a
+// pointer to a basic type.  It returns the number of items successfully
+// parsed.  If that is less than the number of arguments, err will report
+// why.
+func Scanln(r io.Reader, a ...interface{}) (n int, err os.Error) {
+       s := newScanState(r, false)
+       n = s.doScan(a)
+       err = s.err
+       s.free()
+       return
+}
+
+var intBits = uint(reflect.Typeof(int(0)).Size() * 8)
+var uintptrBits = uint(reflect.Typeof(int(0)).Size() * 8)
+var complexError = os.ErrorString("syntax error scanning complex number")
+
+// scanBool converts the token to a boolean value.
+func (s *ss) scanBool(tok string) bool {
+       if s.err != nil {
+               return false
+       }
+       var b bool
+       b, s.err = strconv.Atob(tok)
+       return b
+}
+
+// complexParts returns the strings representing the real and imaginary parts of the string.
+func (s *ss) complexParts(str string) (real, imag string) {
+       if len(str) > 2 && str[0] == '(' && str[len(str)-1] == ')' {
+               str = str[1 : len(str)-1]
+       }
+       real, str = floatPart(str)
+       // Must now have a sign.
+       if len(str) == 0 || (str[0] != '+' && str[0] != '-') {
+               s.err = complexError
+               return "", ""
+       }
+       imag, str = floatPart(str)
+       if str != "i" {
+               s.err = complexError
+               return "", ""
+       }
+       return real, imag
+}
+
+// floatPart returns strings holding the floating point value in the string, followed
+// by the remainder of the string.  That is, it splits str into (number,rest-of-string).
+func floatPart(str string) (first, last string) {
+       i := 0
+       // leading sign?
+       if len(str) > 0 && (str[0] == '+' || str[0] == '-') {
+               i++
+       }
+       // digits?
+       for len(str) > 0 && '0' <= str[i] && str[i] <= '9' {
+               i++
+       }
+       // period?
+       if str[i] == '.' {
+               i++
+       }
+       // fraction?
+       for len(str) > 0 && '0' <= str[i] && str[i] <= '9' {
+               i++
+       }
+       // exponent?
+       if len(str) > 0 && (str[i] == 'e' || str[i] == 'E') {
+               i++
+               // leading sign?
+               if str[0] == '+' || str[0] == '-' {
+                       i++
+               }
+               // digits?
+               for len(str) > 0 && '0' <= str[i] && str[i] <= '9' {
+                       i++
+               }
+       }
+       return str[0:i], str[i:]
+}
+
+// scanFloat converts the string to a float value.
+func (s *ss) scanFloat(str string) float64 {
+       var f float
+       f, s.err = strconv.Atof(str)
+       return float64(f)
+}
+
+// scanFloat32 converts the string to a float32 value.
+func (s *ss) scanFloat32(str string) float64 {
+       var f float32
+       f, s.err = strconv.Atof32(str)
+       return float64(f)
+}
+
+// scanFloat64 converts the string to a float64 value.
+func (s *ss) scanFloat64(str string) float64 {
+       var f float64
+       f, s.err = strconv.Atof64(str)
+       return f
+}
+
+// scanComplex converts the token to a complex128 value.
+// The atof argument is a type-specific reader for the underlying type.
+// If we're reading complex64, atof will parse float32s and convert them
+// to float64's to avoid reproducing this code for each complex type.
+func (s *ss) scanComplex(tok string, atof func(*ss, string) float64) complex128 {
+       if s.err != nil {
+               return 0
+       }
+       sreal, simag := s.complexParts(tok)
+       if s.err != nil {
+               return 0
+       }
+       var real, imag float64
+       real = atof(s, sreal)
+       if s.err != nil {
+               return 0
+       }
+       imag = atof(s, simag)
+       if s.err != nil {
+               return 0
+       }
+       return cmplx(real, imag)
+}
+
+// scanInt converts the token to an int64, but checks that it fits into the
+// specified number of bits.
+func (s *ss) scanInt(tok string, bitSize uint) int64 {
+       if s.err != nil {
+               return 0
+       }
+       var i int64
+       i, s.err = strconv.Atoi64(tok)
+       x := (i << (64 - bitSize)) >> (64 - bitSize)
+       if i != x {
+               s.err = os.ErrorString("integer overflow on token " + tok)
+       }
+       return i
+}
+
+// scanUint converts the token to a uint64, but checks that it fits into the
+// specified number of bits.
+func (s *ss) scanUint(tok string, bitSize uint) uint64 {
+       if s.err != nil {
+               return 0
+       }
+       var i uint64
+       i, s.err = strconv.Atoui64(tok)
+       x := (i << (64 - bitSize)) >> (64 - bitSize)
+       if i != x {
+               s.err = os.ErrorString("unsigned integer overflow on token " + tok)
+       }
+       return i
+}
+
+// doScan does the real work.  At the moment, it handles only pointers to basic types.
+func (s *ss) doScan(a []interface{}) int {
+       for n, param := range a {
+               tok := s.token()
+               switch v := param.(type) {
+               case *bool:
+                       *v = s.scanBool(tok)
+               case *complex:
+                       *v = complex(s.scanComplex(tok, (*ss).scanFloat))
+               case *complex64:
+                       *v = complex64(s.scanComplex(tok, (*ss).scanFloat32))
+               case *complex128:
+                       *v = s.scanComplex(tok, (*ss).scanFloat64)
+               case *int:
+                       *v = int(s.scanInt(tok, intBits))
+               case *int8:
+                       *v = int8(s.scanInt(tok, 8))
+               case *int16:
+                       *v = int16(s.scanInt(tok, 16))
+               case *int32:
+                       *v = int32(s.scanInt(tok, 32))
+               case *int64:
+                       *v = s.scanInt(tok, 64)
+               case *uint:
+                       *v = uint(s.scanUint(tok, intBits))
+               case *uint8:
+                       *v = uint8(s.scanUint(tok, 8))
+               case *uint16:
+                       *v = uint16(s.scanUint(tok, 16))
+               case *uint32:
+                       *v = uint32(s.scanUint(tok, 32))
+               case *uint64:
+                       *v = s.scanUint(tok, 64)
+               case *uintptr:
+                       *v = uintptr(s.scanUint(tok, uintptrBits))
+               case *float:
+                       if s.err == nil {
+                               *v, s.err = strconv.Atof(tok)
+                       } else {
+                               *v = 0
+                       }
+               case *float32:
+                       if s.err == nil {
+                               *v, s.err = strconv.Atof32(tok)
+                       } else {
+                               *v = 0
+                       }
+               case *float64:
+                       if s.err == nil {
+                               *v, s.err = strconv.Atof64(tok)
+                       } else {
+                               *v = 0
+                       }
+               case *string:
+                       *v = tok
+               default:
+                       t := reflect.Typeof(v)
+                       str := t.String()
+                       if _, ok := t.(*reflect.PtrType); !ok {
+                               s.err = os.ErrorString("Scan: type not a pointer: " + str)
+                       } else {
+                               s.err = os.ErrorString("Scan: can't handle type: " + str)
+                       }
+               }
+               if s.err != nil {
+                       return n
+               }
+       }
+       // Check for newline if required.
+       if !s.nlIsSpace {
+               for {
+                       rune, err := s.readRune()
+                       if err != nil {
+                               if err == os.EOF {
+                                       break
+                               }
+                               s.err = err
+                               break
+                       }
+                       if rune == '\n' {
+                               break
+                       }
+                       if !unicode.IsSpace(rune) {
+                               s.err = os.ErrorString("Scan: expected newline")
+                               break
+                       }
+               }
+       }
+       return len(a)
+}
diff --git a/src/pkg/fmt/scan_test.go b/src/pkg/fmt/scan_test.go
new file mode 100644 (file)
index 0000000..a49fb90
--- /dev/null
@@ -0,0 +1,181 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fmt_test
+
+import (
+       . "fmt"
+       "io"
+       "os"
+       "reflect"
+       "strings"
+       "testing"
+)
+
+type ScanTest struct {
+       text string
+       in   interface{}
+       out  interface{}
+}
+
+var boolVal bool
+var intVal int
+var int8Val int8
+var int16Val int16
+var int32Val int32
+var int64Val int64
+var uintVal uint
+var uint8Val uint8
+var uint16Val uint16
+var uint32Val uint32
+var uint64Val uint64
+var floatVal float
+var float32Val float32
+var float64Val float64
+var stringVal string
+var complexVal complex
+var complex64Val complex64
+var complex128Val complex128
+
+var scanTests = []ScanTest{
+       ScanTest{"T\n", &boolVal, true},
+       ScanTest{"21\n", &intVal, 21},
+       ScanTest{"22\n", &int8Val, int8(22)},
+       ScanTest{"23\n", &int16Val, int16(23)},
+       ScanTest{"24\n", &int32Val, int32(24)},
+       ScanTest{"25\n", &int64Val, int64(25)},
+       ScanTest{"127\n", &int8Val, int8(127)},
+       ScanTest{"-21\n", &intVal, -21},
+       ScanTest{"-22\n", &int8Val, int8(-22)},
+       ScanTest{"-23\n", &int16Val, int16(-23)},
+       ScanTest{"-24\n", &int32Val, int32(-24)},
+       ScanTest{"-25\n", &int64Val, int64(-25)},
+       ScanTest{"-128\n", &int8Val, int8(-128)},
+       ScanTest{"+21\n", &intVal, +21},
+       ScanTest{"+22\n", &int8Val, int8(+22)},
+       ScanTest{"+23\n", &int16Val, int16(+23)},
+       ScanTest{"+24\n", &int32Val, int32(+24)},
+       ScanTest{"+25\n", &int64Val, int64(+25)},
+       ScanTest{"+127\n", &int8Val, int8(+127)},
+       ScanTest{"26\n", &uintVal, uint(26)},
+       ScanTest{"27\n", &uint8Val, uint8(27)},
+       ScanTest{"28\n", &uint16Val, uint16(28)},
+       ScanTest{"29\n", &uint32Val, uint32(29)},
+       ScanTest{"30\n", &uint64Val, uint64(30)},
+       ScanTest{"255\n", &uint8Val, uint8(255)},
+       ScanTest{"32767\n", &int16Val, int16(32767)},
+       ScanTest{"2.3\n", &floatVal, 2.3},
+       ScanTest{"2.3e1\n", &float32Val, float32(2.3e1)},
+       ScanTest{"2.3e2\n", &float64Val, float64(2.3e2)},
+       ScanTest{"2.35\n", &stringVal, "2.35"},
+       ScanTest{"(3.4e1-2i)\n", &complexVal, 3.4e1 - 2i},
+       ScanTest{"-3.45e1-3i\n", &complex64Val, complex64(-3.45e1 - 3i)},
+       ScanTest{"-.45e1-1e2i\n", &complex128Val, complex128(-.45e1 - 100i)},
+}
+
+var overflowTests = []ScanTest{
+       ScanTest{"128", &int8Val, 0},
+       ScanTest{"32768", &int16Val, 0},
+       ScanTest{"-129", &int8Val, 0},
+       ScanTest{"-32769", &int16Val, 0},
+       ScanTest{"256", &uint8Val, 0},
+       ScanTest{"65536", &uint16Val, 0},
+       ScanTest{"1e100", &float32Val, 0},
+       ScanTest{"1e500", &float64Val, 0},
+       ScanTest{"(1e100+0i)", &complexVal, 0},
+       ScanTest{"(1+1e100i)", &complex64Val, 0},
+       ScanTest{"(1-1e500i)", &complex128Val, 0},
+}
+
+func testScan(t *testing.T, scan func(r io.Reader, a ...interface{}) (int, os.Error)) {
+       for _, test := range scanTests {
+               r := strings.NewReader(test.text)
+               n, err := scan(r, test.in)
+               if err != nil {
+                       t.Errorf("got error scanning %q: %s", test.text, err)
+                       continue
+               }
+               if n != 1 {
+                       t.Errorf("count error on entry %q: got %d", test.text, n)
+                       continue
+               }
+               // The incoming value may be a pointer
+               v := reflect.NewValue(test.in)
+               if p, ok := v.(*reflect.PtrValue); ok {
+                       v = p.Elem()
+               }
+               val := v.Interface()
+               if !reflect.DeepEqual(val, test.out) {
+                       t.Errorf("scanning %q: expected %v got %v, type %T", test.text, test.out, val, val)
+               }
+       }
+}
+
+func TestScan(t *testing.T) {
+       testScan(t, Scan)
+}
+
+func TestScanln(t *testing.T) {
+       testScan(t, Scanln)
+}
+
+func TestScanOverflow(t *testing.T) {
+       for _, test := range overflowTests {
+               r := strings.NewReader(test.text)
+               _, err := Scan(r, test.in)
+               if err == nil {
+                       t.Errorf("expected overflow scanning %q", test.text)
+                       continue
+               }
+               if strings.Index(err.String(), "overflow") < 0 && strings.Index(err.String(), "too large") < 0 {
+                       t.Errorf("expected overflow error scanning %q: %s", test.text, err)
+               }
+       }
+}
+
+func TestScanMultiple(t *testing.T) {
+       text := "1 2 3 x"
+       r := strings.NewReader(text)
+       var a, b, c, d int
+       n, err := Scan(r, &a, &b, &c, &d)
+       if n != 3 {
+               t.Errorf("count error: expected 3: got %d", n)
+       }
+       if err == nil {
+               t.Errorf("expected error scanning ", text)
+       }
+}
+
+func TestScanNotPointer(t *testing.T) {
+       r := strings.NewReader("1")
+       var a int
+       _, err := Scan(r, a)
+       if err == nil {
+               t.Error("expected error scanning non-pointer")
+       } else if strings.Index(err.String(), "pointer") < 0 {
+               t.Errorf("expected pointer error scanning non-pointer, got: %s", err)
+       }
+}
+
+func TestScanlnNoNewline(t *testing.T) {
+       r := strings.NewReader("1 x\n")
+       var a int
+       _, err := Scanln(r, &a)
+       if err == nil {
+               t.Error("expected error scanning string missing newline")
+       } else if strings.Index(err.String(), "newline") < 0 {
+               t.Errorf("expected newline error scanning string missing newline, got: %s", err)
+       }
+}
+
+func TestScanlnWithMiddleNewline(t *testing.T) {
+       r := strings.NewReader("123\n456\n")
+       var a, b int
+       _, err := Scanln(r, &a, &b)
+       if err == nil {
+               t.Error("expected error scanning string with extra newline")
+       } else if strings.Index(err.String(), "newline") < 0 {
+               t.Errorf("expected newline error scanning string with extra newline, got: %s", err)
+       }
+}