]> Cypherpunks repositories - gostls13.git/commitdiff
os: make readConsole handle its input and output correctly
authorAlex Brainman <alex.brainman@gmail.com>
Wed, 21 Sep 2016 01:19:36 +0000 (11:19 +1000)
committerAlex Brainman <alex.brainman@gmail.com>
Thu, 13 Oct 2016 06:16:53 +0000 (06:16 +0000)
This CL introduces first test for readConsole. And new test
discovered couple of problems with readConsole.

Console characters consist of multiple bytes each, but byte blocks
returned by syscall.ReadFile have no character boundaries. Some
multi-byte characters might start at the end of one block, and end
at the start of next block. readConsole feeds these blocks to
syscall.MultiByteToWideChar to convert them into utf16, but if some
multi-byte characters have no ending or starting bytes, the
syscall.MultiByteToWideChar might get confused. Current version of
syscall.MultiByteToWideChar call will make
syscall.MultiByteToWideChar ignore all these not complete
multi-byte characters.

The CL solves this issue by changing processing from "randomly
sized block of bytes at a time" to "one multi-byte character at a
time". New readConsole code calls syscall.ReadFile to get 1 byte
first. Then it feeds this byte to syscall.MultiByteToWideChar.
The new syscall.MultiByteToWideChar call uses MB_ERR_INVALID_CHARS
flag to make syscall.MultiByteToWideChar return error if input is
not complete character. If syscall.MultiByteToWideChar returns
correspondent error, we read another byte and pass 2 byte buffer
into syscall.MultiByteToWideChar, and so on until success.

Old readConsole code would also sometimes return no data if user
buffer was smaller then uint16 size, which would confuse callers
that supply 1 byte buffer. This CL fixes that problem too.

Fixes #17097

Change-Id: I88136cdf6a7bf3aed5fbb9ad2c759b6c0304ce30
Reviewed-on: https://go-review.googlesource.com/29493
Run-TryBot: Alex Brainman <alex.brainman@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
src/internal/syscall/windows/syscall_windows.go
src/os/export_windows_test.go [new file with mode: 0644]
src/os/file_windows.go
src/os/os_windows_test.go

index 77d6033a353d62ca97cc00cab115389484de8e5a..c4e59b28bd272aae635f5e51e633312b6364469c 100644 (file)
@@ -6,6 +6,10 @@ package windows
 
 import "syscall"
 
+const (
+       ERROR_NO_UNICODE_TRANSLATION syscall.Errno = 1113
+)
+
 const GAA_FLAG_INCLUDE_PREFIX = 0x00000010
 
 const (
@@ -137,6 +141,8 @@ func Rename(oldpath, newpath string) error {
        return MoveFileEx(from, to, MOVEFILE_REPLACE_EXISTING)
 }
 
+const MB_ERR_INVALID_CHARS = 8
+
 //sys  GetACP() (acp uint32) = kernel32.GetACP
 //sys  GetConsoleCP() (ccp uint32) = kernel32.GetConsoleCP
 //sys  MultiByteToWideChar(codePage uint32, dwFlags uint32, str *byte, nstr int32, wchar *uint16, nwchar int32) (nwrite int32, err error) = kernel32.MultiByteToWideChar
diff --git a/src/os/export_windows_test.go b/src/os/export_windows_test.go
new file mode 100644 (file)
index 0000000..fbfb6b0
--- /dev/null
@@ -0,0 +1,14 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package os
+
+// Export for testing.
+
+var (
+       NewConsoleFile                    = newConsoleFile
+       GetCPP                            = &getCP
+       ReadFileP                         = &readFile
+       ResetGetConsoleCPAndReadFileFuncs = resetGetConsoleCPAndReadFileFuncs
+)
index efbf0e85fbbb21239a57f58fabb53042e7c5c661..ed06b55535886f939377d895c13e5de354f256e7 100644 (file)
@@ -5,6 +5,7 @@
 package os
 
 import (
+       "errors"
        "internal/syscall/windows"
        "io"
        "runtime"
@@ -28,7 +29,7 @@ type file struct {
        // only for console io
        isConsole bool
        lastbits  []byte // first few bytes of the last incomplete rune in last write
-       readbuf   []rune // input console buffer
+       readbuf   []byte // last few bytes of the last read that did not fit in the user buffer
 }
 
 // Fd returns the Windows handle referencing the open file.
@@ -44,20 +45,28 @@ func (file *File) Fd() uintptr {
 // Unlike NewFile, it does not check that h is syscall.InvalidHandle.
 func newFile(h syscall.Handle, name string) *File {
        f := &File{&file{fd: h, name: name}}
-       var m uint32
-       if syscall.GetConsoleMode(f.fd, &m) == nil {
-               f.isConsole = true
-       }
        runtime.SetFinalizer(f.file, (*file).close)
        return f
 }
 
+// newConsoleFile creates new File that will be used as console.
+func newConsoleFile(h syscall.Handle, name string) *File {
+       f := newFile(h, name)
+       f.isConsole = true
+       f.readbuf = make([]byte, 0, 4)
+       return f
+}
+
 // NewFile returns a new File with the given file descriptor and name.
 func NewFile(fd uintptr, name string) *File {
        h := syscall.Handle(fd)
        if h == syscall.InvalidHandle {
                return nil
        }
+       var m uint32
+       if syscall.GetConsoleMode(h, &m) == nil {
+               return newConsoleFile(h, name)
+       }
        return newFile(h, name)
 }
 
@@ -191,59 +200,101 @@ func (file *file) close() error {
        return err
 }
 
+var (
+       // These variables are used for testing readConsole.
+       getCP    = windows.GetConsoleCP
+       readFile = syscall.ReadFile
+)
+
+func resetGetConsoleCPAndReadFileFuncs() {
+       getCP = windows.GetConsoleCP
+       readFile = syscall.ReadFile
+}
+
+// copyReadConsoleBuffer copies data stored in f.readbuf into buf.
+// It adjusts f.readbuf accordingly and returns number of bytes copied.
+func (f *File) copyReadConsoleBuffer(buf []byte) (n int, err error) {
+       n = copy(buf, f.readbuf)
+       newsize := copy(f.readbuf, f.readbuf[n:])
+       f.readbuf = f.readbuf[:newsize]
+       return n, nil
+}
+
+// readOneUTF16FromConsole reads single character from console,
+// converts it into utf16 and return it to the caller.
+func (f *File) readOneUTF16FromConsole() (uint16, error) {
+       var buf [1]byte
+       mbytes := make([]byte, 0, 4)
+       cp := getCP()
+       for {
+               var nmb uint32
+               err := readFile(f.fd, buf[:], &nmb, nil)
+               if err != nil {
+                       return 0, err
+               }
+               if nmb == 0 {
+                       continue
+               }
+               mbytes = append(mbytes, buf[0])
+
+               // Convert from 8-bit console encoding to UTF16.
+               // MultiByteToWideChar defaults to Unicode NFC form, which is the expected one.
+               nwc, err := windows.MultiByteToWideChar(cp, windows.MB_ERR_INVALID_CHARS, &mbytes[0], int32(len(mbytes)), nil, 0)
+               if err != nil {
+                       if err == windows.ERROR_NO_UNICODE_TRANSLATION {
+                               continue
+                       }
+                       return 0, err
+               }
+               if nwc != 1 {
+                       return 0, errors.New("MultiByteToWideChar returns " + itoa(int(nwc)) + " characters, but only 1 expected")
+               }
+               var wchars [1]uint16
+               nwc, err = windows.MultiByteToWideChar(cp, windows.MB_ERR_INVALID_CHARS, &mbytes[0], int32(len(mbytes)), &wchars[0], nwc)
+               if err != nil {
+                       return 0, err
+               }
+               return wchars[0], nil
+       }
+}
+
 // readConsole reads utf16 characters from console File,
-// encodes them into utf8 and stores them in buffer b.
+// encodes them into utf8 and stores them in buffer buf.
 // It returns the number of utf8 bytes read and an error, if any.
-func (f *File) readConsole(b []byte) (n int, err error) {
-       if len(b) == 0 {
+func (f *File) readConsole(buf []byte) (n int, err error) {
+       if len(buf) == 0 {
                return 0, nil
        }
-       if len(f.readbuf) == 0 {
-               numBytes := len(b)
-               // Windows  can't read bytes over max of int16.
-               // Some versions of Windows can read even less.
-               // See golang.org/issue/13697.
-               if numBytes > 10000 {
-                       numBytes = 10000
-               }
-               mbytes := make([]byte, numBytes)
-               var nmb uint32
-               err := syscall.ReadFile(f.fd, mbytes, &nmb, nil)
+       if len(f.readbuf) > 0 {
+               return f.copyReadConsoleBuffer(buf)
+       }
+       wchar, err := f.readOneUTF16FromConsole()
+       if err != nil {
+               return 0, err
+       }
+       r := rune(wchar)
+       if utf16.IsSurrogate(r) {
+               wchar, err := f.readOneUTF16FromConsole()
                if err != nil {
                        return 0, err
                }
-               if nmb > 0 {
-                       var pmb *byte
-                       if len(b) > 0 {
-                               pmb = &mbytes[0]
-                       }
-                       ccp := windows.GetConsoleCP()
-                       // Convert from 8-bit console encoding to UTF16.
-                       // MultiByteToWideChar defaults to Unicode NFC form, which is the expected one.
-                       nwc, err := windows.MultiByteToWideChar(ccp, 0, pmb, int32(nmb), nil, 0)
-                       if err != nil {
-                               return 0, err
-                       }
-                       wchars := make([]uint16, nwc)
-                       pwc := &wchars[0]
-                       nwc, err = windows.MultiByteToWideChar(ccp, 0, pmb, int32(nmb), pwc, nwc)
-                       if err != nil {
-                               return 0, err
-                       }
-                       f.readbuf = utf16.Decode(wchars[:nwc])
-               }
+               r = utf16.DecodeRune(r, rune(wchar))
        }
-       for i, r := range f.readbuf {
-               if utf8.RuneLen(r) > len(b) {
-                       f.readbuf = f.readbuf[i:]
-                       return n, nil
+       if nr := utf8.RuneLen(r); nr > len(buf) {
+               start := len(f.readbuf)
+               for ; nr > 0; nr-- {
+                       f.readbuf = append(f.readbuf, 0)
                }
-               nr := utf8.EncodeRune(b, r)
-               b = b[nr:]
+               utf8.EncodeRune(f.readbuf[start:cap(f.readbuf)], r)
+       } else {
+               utf8.EncodeRune(buf, r)
+               buf = buf[nr:]
                n += nr
        }
-       f.readbuf = nil
-       return n, nil
+       if n > 0 {
+               return n, nil
+       }
+       return f.copyReadConsoleBuffer(buf)
 }
 
 // read reads up to len(b) bytes from the File.
index acdf4f17a658fb9ad5d282fc36f9d329cf388e41..741df3ff1e97e4188d86f033d8dfab5dcb0a4342 100644 (file)
@@ -5,6 +5,8 @@
 package os_test
 
 import (
+       "bytes"
+       "encoding/hex"
        "fmt"
        "internal/syscall/windows"
        "internal/testenv"
@@ -545,3 +547,86 @@ func TestStatSymlinkLoop(t *testing.T) {
                t.Errorf("expected *PathError with ELOOP, got %T: %v\n", err, err)
        }
 }
+
+func TestReadStdin(t *testing.T) {
+       defer os.ResetGetConsoleCPAndReadFileFuncs()
+
+       testConsole := os.NewConsoleFile(syscall.Stdin, "test")
+
+       var (
+               hiraganaA_CP932 = []byte{0x82, 0xa0}
+               hiraganaA_UTF8  = "\u3042"
+
+               tests = []struct {
+                       cp     uint32
+                       input  []byte
+                       output string // always utf8
+               }{
+                       {
+                               cp:     437,
+                               input:  []byte("abc"),
+                               output: "abc",
+                       },
+                       {
+                               cp:     850,
+                               input:  []byte{0x84, 0x94, 0x81},
+                               output: "äöü",
+                       },
+                       {
+                               cp:     932,
+                               input:  hiraganaA_CP932,
+                               output: hiraganaA_UTF8,
+                       },
+                       {
+                               cp:     932,
+                               input:  bytes.Repeat(hiraganaA_CP932, 2),
+                               output: strings.Repeat(hiraganaA_UTF8, 2),
+                       },
+                       {
+                               cp:     932,
+                               input:  append(bytes.Repeat(hiraganaA_CP932, 3), '.'),
+                               output: strings.Repeat(hiraganaA_UTF8, 3) + ".",
+                       },
+                       {
+                               cp:     932,
+                               input:  append(append([]byte("hello"), hiraganaA_CP932...), []byte("world")...),
+                               output: "hello" + hiraganaA_UTF8 + "world",
+                       },
+                       {
+                               cp:     932,
+                               input:  append(append([]byte("hello"), bytes.Repeat(hiraganaA_CP932, 5)...), []byte("world")...),
+                               output: "hello" + strings.Repeat(hiraganaA_UTF8, 5) + "world",
+                       },
+               }
+       )
+       for _, bufsize := range []int{1, 2, 3, 4, 5, 8, 10, 16, 20, 50, 100} {
+       nextTest:
+               for ti, test := range tests {
+                       input := bytes.NewBuffer(test.input)
+                       *os.ReadFileP = func(h syscall.Handle, buf []byte, done *uint32, o *syscall.Overlapped) error {
+                               n, err := input.Read(buf)
+                               *done = uint32(n)
+                               return err
+                       }
+                       *os.GetCPP = func() uint32 {
+                               return test.cp
+                       }
+                       var bigbuf []byte
+                       for len(bigbuf) < len([]byte(test.output)) {
+                               buf := make([]byte, bufsize)
+                               n, err := testConsole.Read(buf)
+                               if err != nil {
+                                       t.Errorf("test=%d bufsize=%d: read failed: %v", ti, bufsize, err)
+                                       continue nextTest
+                               }
+                               bigbuf = append(bigbuf, buf[:n]...)
+                       }
+                       have := hex.Dump(bigbuf)
+                       expected := hex.Dump([]byte(test.output))
+                       if have != expected {
+                               t.Errorf("test=%d bufsize=%d: %q expected, but %q received", ti, bufsize, expected, have)
+                               continue nextTest
+                       }
+               }
+       }
+}