]> Cypherpunks repositories - gostls13.git/commitdiff
net: add Buffers type, do writev on unix
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 27 Sep 2016 20:50:57 +0000 (20:50 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 29 Sep 2016 20:33:45 +0000 (20:33 +0000)
No fast path currently for solaris, windows, nacl, plan9.

Fixes #13451

Change-Id: I24b3233a2e3a57fc6445e276a5c0d7b097884007
Reviewed-on: https://go-review.googlesource.com/29951
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/fd_unix.go
src/net/net.go
src/net/net_test.go
src/net/writev_test.go [new file with mode: 0644]
src/net/writev_unix.go [new file with mode: 0644]
src/syscall/syscall_nacl.go

index 11dde76977e40e2f52d476cd3ce19d9465b27224..1296bc56b2021eb57dd37ca93a8df45cb264b80c 100644 (file)
@@ -29,6 +29,9 @@ type netFD struct {
        laddr       Addr
        raddr       Addr
 
+       // writev cache.
+       iovecs *[]syscall.Iovec
+
        // wait server
        pd pollDesc
 }
index d6812d1ef054dc03e57b9f53fa7689da924c46dd..8ab952ae72b7846dd28194d07458dbe5cdfd8dd0 100644 (file)
@@ -604,3 +604,66 @@ func acquireThread() {
 func releaseThread() {
        <-threadLimit
 }
+
+// buffersWriter is the interface implemented by Conns that support a
+// "writev"-like batch write optimization.
+// writeBuffers should fully consume and write all chunks from the
+// provided Buffers, else it should report a non-nil error.
+type buffersWriter interface {
+       writeBuffers(*Buffers) (int64, error)
+}
+
+var testHookDidWritev = func(wrote int) {}
+
+// Buffers contains zero or more runs of bytes to write.
+//
+// On certain machines, for certain types of connections, this is
+// optimized into an OS-specific batch write operation (such as
+// "writev").
+type Buffers [][]byte
+
+var (
+       _ io.WriterTo = (*Buffers)(nil)
+       _ io.Reader   = (*Buffers)(nil)
+)
+
+func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) {
+       if wv, ok := w.(buffersWriter); ok {
+               return wv.writeBuffers(v)
+       }
+       for _, b := range *v {
+               nb, err := w.Write(b)
+               n += int64(nb)
+               if err != nil {
+                       v.consume(n)
+                       return n, err
+               }
+       }
+       v.consume(n)
+       return n, nil
+}
+
+func (v *Buffers) Read(p []byte) (n int, err error) {
+       for len(p) > 0 && len(*v) > 0 {
+               n0 := copy(p, (*v)[0])
+               v.consume(int64(n0))
+               p = p[n0:]
+               n += n0
+       }
+       if len(*v) == 0 {
+               err = io.EOF
+       }
+       return
+}
+
+func (v *Buffers) consume(n int64) {
+       for len(*v) > 0 {
+               ln0 := int64(len((*v)[0]))
+               if ln0 > n {
+                       (*v)[0] = (*v)[0][n:]
+                       return
+               }
+               n -= ln0
+               *v = (*v)[1:]
+       }
+}
index b2f825daffc93c8f97fc0efaa917d4a0779ed27d..1968ff323e5e612e66ac961bcdc29d50e42b0151 100644 (file)
@@ -414,3 +414,38 @@ func TestZeroByteRead(t *testing.T) {
                }
        }
 }
+
+// withTCPConnPair sets up a TCP connection between two peers, then
+// runs peer1 and peer2 concurrently. withTCPConnPair returns when
+// both have completed.
+func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
+       ln, err := newLocalListener("tcp")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer ln.Close()
+       errc := make(chan error, 2)
+       go func() {
+               c1, err := ln.Accept()
+               if err != nil {
+                       errc <- err
+                       return
+               }
+               defer c1.Close()
+               errc <- peer1(c1.(*TCPConn))
+       }()
+       go func() {
+               c2, err := Dial("tcp", ln.Addr().String())
+               if err != nil {
+                       errc <- err
+                       return
+               }
+               defer c2.Close()
+               errc <- peer2(c2.(*TCPConn))
+       }()
+       for i := 0; i < 2; i++ {
+               if err := <-errc; err != nil {
+                       t.Fatal(err)
+               }
+       }
+}
diff --git a/src/net/writev_test.go b/src/net/writev_test.go
new file mode 100644 (file)
index 0000000..cc53adc
--- /dev/null
@@ -0,0 +1,171 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "bytes"
+       "fmt"
+       "io"
+       "io/ioutil"
+       "reflect"
+       "runtime"
+       "sync"
+       "testing"
+)
+
+func TestBuffers_read(t *testing.T) {
+       const story = "once upon a time in Gopherland ... "
+       buffers := Buffers{
+               []byte("once "),
+               []byte("upon "),
+               []byte("a "),
+               []byte("time "),
+               []byte("in "),
+               []byte("Gopherland ... "),
+       }
+       got, err := ioutil.ReadAll(&buffers)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if string(got) != story {
+               t.Errorf("read %q; want %q", got, story)
+       }
+       if len(buffers) != 0 {
+               t.Errorf("len(buffers) = %d; want 0", len(buffers))
+       }
+}
+
+func TestBuffers_consume(t *testing.T) {
+       tests := []struct {
+               in      Buffers
+               consume int64
+               want    Buffers
+       }{
+               {
+                       in:      Buffers{[]byte("foo"), []byte("bar")},
+                       consume: 0,
+                       want:    Buffers{[]byte("foo"), []byte("bar")},
+               },
+               {
+                       in:      Buffers{[]byte("foo"), []byte("bar")},
+                       consume: 2,
+                       want:    Buffers{[]byte("o"), []byte("bar")},
+               },
+               {
+                       in:      Buffers{[]byte("foo"), []byte("bar")},
+                       consume: 3,
+                       want:    Buffers{[]byte("bar")},
+               },
+               {
+                       in:      Buffers{[]byte("foo"), []byte("bar")},
+                       consume: 4,
+                       want:    Buffers{[]byte("ar")},
+               },
+               {
+                       in:      Buffers{nil, nil, nil, []byte("bar")},
+                       consume: 1,
+                       want:    Buffers{[]byte("ar")},
+               },
+               {
+                       in:      Buffers{nil, nil, nil, []byte("foo")},
+                       consume: 0,
+                       want:    Buffers{[]byte("foo")},
+               },
+               {
+                       in:      Buffers{nil, nil, nil},
+                       consume: 0,
+                       want:    Buffers{},
+               },
+       }
+       for i, tt := range tests {
+               in := tt.in
+               in.consume(tt.consume)
+               if !reflect.DeepEqual(in, tt.want) {
+                       t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
+               }
+       }
+}
+
+func TestBuffers_WriteTo(t *testing.T) {
+       for _, name := range []string{"WriteTo", "Copy"} {
+               for _, size := range []int{0, 10, 1023, 1024, 1025} {
+                       t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
+                               testBuffer_writeTo(t, size, name == "Copy")
+                       })
+               }
+       }
+}
+
+func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
+       oldHook := testHookDidWritev
+       defer func() { testHookDidWritev = oldHook }()
+       var writeLog struct {
+               sync.Mutex
+               log []int
+       }
+       testHookDidWritev = func(size int) {
+               writeLog.Lock()
+               writeLog.log = append(writeLog.log, size)
+               writeLog.Unlock()
+       }
+       var want bytes.Buffer
+       for i := 0; i < chunks; i++ {
+               want.WriteByte(byte(i))
+       }
+
+       withTCPConnPair(t, func(c *TCPConn) error {
+               buffers := make(Buffers, chunks)
+               for i := range buffers {
+                       buffers[i] = want.Bytes()[i : i+1]
+               }
+               var n int64
+               var err error
+               if useCopy {
+                       n, err = io.Copy(c, &buffers)
+               } else {
+                       n, err = buffers.WriteTo(c)
+               }
+               if err != nil {
+                       return err
+               }
+               if len(buffers) != 0 {
+                       return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
+               }
+               if n != int64(want.Len()) {
+                       return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
+               }
+               return nil
+       }, func(c *TCPConn) error {
+               all, err := ioutil.ReadAll(c)
+               if !bytes.Equal(all, want.Bytes()) || err != nil {
+                       return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
+               }
+
+               writeLog.Lock() // no need to unlock
+               var gotSum int
+               for _, v := range writeLog.log {
+                       gotSum += v
+               }
+
+               var wantSum int
+               var wantMinCalls int
+               switch runtime.GOOS {
+               case "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd":
+                       wantSum = want.Len()
+                       v := chunks
+                       for v > 0 {
+                               wantMinCalls++
+                               v -= 1024
+                       }
+               }
+               if len(writeLog.log) < wantMinCalls {
+                       t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
+               }
+               if gotSum != wantSum {
+                       t.Errorf("writev call sum  = %v; want %v", gotSum, wantSum)
+               }
+               return nil
+       })
+}
diff --git a/src/net/writev_unix.go b/src/net/writev_unix.go
new file mode 100644 (file)
index 0000000..ac4f7cf
--- /dev/null
@@ -0,0 +1,91 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux netbsd openbsd
+
+package net
+
+import (
+       "io"
+       "os"
+       "syscall"
+       "unsafe"
+)
+
+func (c *conn) writeBuffers(v *Buffers) (int64, error) {
+       if !c.ok() {
+               return 0, syscall.EINVAL
+       }
+       n, err := c.fd.writeBuffers(v)
+       if err != nil {
+               return n, &OpError{Op: "writev", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       return n, nil
+}
+
+func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) {
+       if err := fd.writeLock(); err != nil {
+               return 0, err
+       }
+       defer fd.writeUnlock()
+       if err := fd.pd.prepareWrite(); err != nil {
+               return 0, err
+       }
+
+       var iovecs []syscall.Iovec
+       if fd.iovecs != nil {
+               iovecs = *fd.iovecs
+       }
+       // TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is
+       // 1024 and this seems conservative enough for now. Darwin's
+       // UIO_MAXIOV also seems to be 1024.
+       maxVec := 1024
+
+       for len(*v) > 0 {
+               iovecs = iovecs[:0]
+               for _, chunk := range *v {
+                       if len(chunk) == 0 {
+                               continue
+                       }
+                       iovecs = append(iovecs, syscall.Iovec{Base: &chunk[0]})
+                       iovecs[len(iovecs)-1].SetLen(len(chunk))
+                       if len(iovecs) == maxVec {
+                               break
+                       }
+               }
+               if len(iovecs) == 0 {
+                       break
+               }
+               fd.iovecs = &iovecs // cache
+
+               wrote, _, e0 := syscall.Syscall(syscall.SYS_WRITEV,
+                       uintptr(fd.sysfd),
+                       uintptr(unsafe.Pointer(&iovecs[0])),
+                       uintptr(len(iovecs)))
+               if wrote < 0 {
+                       wrote = 0
+               }
+               testHookDidWritev(int(wrote))
+               n += int64(wrote)
+               v.consume(int64(wrote))
+               if e0 == syscall.EAGAIN {
+                       if err = fd.pd.waitWrite(); err == nil {
+                               continue
+                       }
+               } else if e0 != 0 {
+                       err = syscall.Errno(e0)
+               }
+               if err != nil {
+                       break
+               }
+               if n == 0 {
+                       err = io.ErrUnexpectedEOF
+                       break
+               }
+       }
+       if _, ok := err.(syscall.Errno); ok {
+               err = os.NewSyscallError("writev", err)
+       }
+       return n, err
+}
index 88e4e3a9dc3f760873f42334e927457adf115ff9..3247505288e6df3285977d5933b5e0910a193a78 100644 (file)
@@ -296,3 +296,5 @@ func RouteRIB(facility, param int) ([]byte, error)                { return nil,
 func ParseRoutingMessage(b []byte) ([]RoutingMessage, error)      { return nil, ENOSYS }
 func ParseRoutingSockaddr(msg RoutingMessage) ([]Sockaddr, error) { return nil, ENOSYS }
 func SysctlUint32(name string) (value uint32, err error)          { return 0, ENOSYS }
+
+type Iovec struct{} // dummy