From: Brad Fitzpatrick Date: Tue, 27 Sep 2016 20:50:57 +0000 (+0000) Subject: net: add Buffers type, do writev on unix X-Git-Tag: go1.8beta1~1108 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=8e69d43b32be578cd36eebe491b6e1205dbf32a4;p=gostls13.git net: add Buffers type, do writev on unix No fast path currently for solaris, windows, nacl, plan9. Fixes #13451 Change-Id: I24b3233a2e3a57fc6445e276a5c0d7b097884007 Reviewed-on: https://go-review.googlesource.com/29951 Reviewed-by: Ian Lance Taylor Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- diff --git a/src/net/fd_unix.go b/src/net/fd_unix.go index 11dde76977..1296bc56b2 100644 --- a/src/net/fd_unix.go +++ b/src/net/fd_unix.go @@ -29,6 +29,9 @@ type netFD struct { laddr Addr raddr Addr + // writev cache. + iovecs *[]syscall.Iovec + // wait server pd pollDesc } diff --git a/src/net/net.go b/src/net/net.go index d6812d1ef0..8ab952ae72 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -604,3 +604,66 @@ func acquireThread() { func releaseThread() { <-threadLimit } + +// buffersWriter is the interface implemented by Conns that support a +// "writev"-like batch write optimization. +// writeBuffers should fully consume and write all chunks from the +// provided Buffers, else it should report a non-nil error. +type buffersWriter interface { + writeBuffers(*Buffers) (int64, error) +} + +var testHookDidWritev = func(wrote int) {} + +// Buffers contains zero or more runs of bytes to write. +// +// On certain machines, for certain types of connections, this is +// optimized into an OS-specific batch write operation (such as +// "writev"). +type Buffers [][]byte + +var ( + _ io.WriterTo = (*Buffers)(nil) + _ io.Reader = (*Buffers)(nil) +) + +func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) { + if wv, ok := w.(buffersWriter); ok { + return wv.writeBuffers(v) + } + for _, b := range *v { + nb, err := w.Write(b) + n += int64(nb) + if err != nil { + v.consume(n) + return n, err + } + } + v.consume(n) + return n, nil +} + +func (v *Buffers) Read(p []byte) (n int, err error) { + for len(p) > 0 && len(*v) > 0 { + n0 := copy(p, (*v)[0]) + v.consume(int64(n0)) + p = p[n0:] + n += n0 + } + if len(*v) == 0 { + err = io.EOF + } + return +} + +func (v *Buffers) consume(n int64) { + for len(*v) > 0 { + ln0 := int64(len((*v)[0])) + if ln0 > n { + (*v)[0] = (*v)[0][n:] + return + } + n -= ln0 + *v = (*v)[1:] + } +} diff --git a/src/net/net_test.go b/src/net/net_test.go index b2f825daff..1968ff323e 100644 --- a/src/net/net_test.go +++ b/src/net/net_test.go @@ -414,3 +414,38 @@ func TestZeroByteRead(t *testing.T) { } } } + +// withTCPConnPair sets up a TCP connection between two peers, then +// runs peer1 and peer2 concurrently. withTCPConnPair returns when +// both have completed. +func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + errc := make(chan error, 2) + go func() { + c1, err := ln.Accept() + if err != nil { + errc <- err + return + } + defer c1.Close() + errc <- peer1(c1.(*TCPConn)) + }() + go func() { + c2, err := Dial("tcp", ln.Addr().String()) + if err != nil { + errc <- err + return + } + defer c2.Close() + errc <- peer2(c2.(*TCPConn)) + }() + for i := 0; i < 2; i++ { + if err := <-errc; err != nil { + t.Fatal(err) + } + } +} diff --git a/src/net/writev_test.go b/src/net/writev_test.go new file mode 100644 index 0000000000..cc53adcdd1 --- /dev/null +++ b/src/net/writev_test.go @@ -0,0 +1,171 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "reflect" + "runtime" + "sync" + "testing" +) + +func TestBuffers_read(t *testing.T) { + const story = "once upon a time in Gopherland ... " + buffers := Buffers{ + []byte("once "), + []byte("upon "), + []byte("a "), + []byte("time "), + []byte("in "), + []byte("Gopherland ... "), + } + got, err := ioutil.ReadAll(&buffers) + if err != nil { + t.Fatal(err) + } + if string(got) != story { + t.Errorf("read %q; want %q", got, story) + } + if len(buffers) != 0 { + t.Errorf("len(buffers) = %d; want 0", len(buffers)) + } +} + +func TestBuffers_consume(t *testing.T) { + tests := []struct { + in Buffers + consume int64 + want Buffers + }{ + { + in: Buffers{[]byte("foo"), []byte("bar")}, + consume: 0, + want: Buffers{[]byte("foo"), []byte("bar")}, + }, + { + in: Buffers{[]byte("foo"), []byte("bar")}, + consume: 2, + want: Buffers{[]byte("o"), []byte("bar")}, + }, + { + in: Buffers{[]byte("foo"), []byte("bar")}, + consume: 3, + want: Buffers{[]byte("bar")}, + }, + { + in: Buffers{[]byte("foo"), []byte("bar")}, + consume: 4, + want: Buffers{[]byte("ar")}, + }, + { + in: Buffers{nil, nil, nil, []byte("bar")}, + consume: 1, + want: Buffers{[]byte("ar")}, + }, + { + in: Buffers{nil, nil, nil, []byte("foo")}, + consume: 0, + want: Buffers{[]byte("foo")}, + }, + { + in: Buffers{nil, nil, nil}, + consume: 0, + want: Buffers{}, + }, + } + for i, tt := range tests { + in := tt.in + in.consume(tt.consume) + if !reflect.DeepEqual(in, tt.want) { + t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want) + } + } +} + +func TestBuffers_WriteTo(t *testing.T) { + for _, name := range []string{"WriteTo", "Copy"} { + for _, size := range []int{0, 10, 1023, 1024, 1025} { + t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) { + testBuffer_writeTo(t, size, name == "Copy") + }) + } + } +} + +func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) { + oldHook := testHookDidWritev + defer func() { testHookDidWritev = oldHook }() + var writeLog struct { + sync.Mutex + log []int + } + testHookDidWritev = func(size int) { + writeLog.Lock() + writeLog.log = append(writeLog.log, size) + writeLog.Unlock() + } + var want bytes.Buffer + for i := 0; i < chunks; i++ { + want.WriteByte(byte(i)) + } + + withTCPConnPair(t, func(c *TCPConn) error { + buffers := make(Buffers, chunks) + for i := range buffers { + buffers[i] = want.Bytes()[i : i+1] + } + var n int64 + var err error + if useCopy { + n, err = io.Copy(c, &buffers) + } else { + n, err = buffers.WriteTo(c) + } + if err != nil { + return err + } + if len(buffers) != 0 { + return fmt.Errorf("len(buffers) = %d; want 0", len(buffers)) + } + if n != int64(want.Len()) { + return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len()) + } + return nil + }, func(c *TCPConn) error { + all, err := ioutil.ReadAll(c) + if !bytes.Equal(all, want.Bytes()) || err != nil { + return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes()) + } + + writeLog.Lock() // no need to unlock + var gotSum int + for _, v := range writeLog.log { + gotSum += v + } + + var wantSum int + var wantMinCalls int + switch runtime.GOOS { + case "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd": + wantSum = want.Len() + v := chunks + for v > 0 { + wantMinCalls++ + v -= 1024 + } + } + if len(writeLog.log) < wantMinCalls { + t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls) + } + if gotSum != wantSum { + t.Errorf("writev call sum = %v; want %v", gotSum, wantSum) + } + return nil + }) +} diff --git a/src/net/writev_unix.go b/src/net/writev_unix.go new file mode 100644 index 0000000000..ac4f7cf61a --- /dev/null +++ b/src/net/writev_unix.go @@ -0,0 +1,91 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package net + +import ( + "io" + "os" + "syscall" + "unsafe" +) + +func (c *conn) writeBuffers(v *Buffers) (int64, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + n, err := c.fd.writeBuffers(v) + if err != nil { + return n, &OpError{Op: "writev", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + } + return n, nil +} + +func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) { + if err := fd.writeLock(); err != nil { + return 0, err + } + defer fd.writeUnlock() + if err := fd.pd.prepareWrite(); err != nil { + return 0, err + } + + var iovecs []syscall.Iovec + if fd.iovecs != nil { + iovecs = *fd.iovecs + } + // TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is + // 1024 and this seems conservative enough for now. Darwin's + // UIO_MAXIOV also seems to be 1024. + maxVec := 1024 + + for len(*v) > 0 { + iovecs = iovecs[:0] + for _, chunk := range *v { + if len(chunk) == 0 { + continue + } + iovecs = append(iovecs, syscall.Iovec{Base: &chunk[0]}) + iovecs[len(iovecs)-1].SetLen(len(chunk)) + if len(iovecs) == maxVec { + break + } + } + if len(iovecs) == 0 { + break + } + fd.iovecs = &iovecs // cache + + wrote, _, e0 := syscall.Syscall(syscall.SYS_WRITEV, + uintptr(fd.sysfd), + uintptr(unsafe.Pointer(&iovecs[0])), + uintptr(len(iovecs))) + if wrote < 0 { + wrote = 0 + } + testHookDidWritev(int(wrote)) + n += int64(wrote) + v.consume(int64(wrote)) + if e0 == syscall.EAGAIN { + if err = fd.pd.waitWrite(); err == nil { + continue + } + } else if e0 != 0 { + err = syscall.Errno(e0) + } + if err != nil { + break + } + if n == 0 { + err = io.ErrUnexpectedEOF + break + } + } + if _, ok := err.(syscall.Errno); ok { + err = os.NewSyscallError("writev", err) + } + return n, err +} diff --git a/src/syscall/syscall_nacl.go b/src/syscall/syscall_nacl.go index 88e4e3a9dc..3247505288 100644 --- a/src/syscall/syscall_nacl.go +++ b/src/syscall/syscall_nacl.go @@ -296,3 +296,5 @@ func RouteRIB(facility, param int) ([]byte, error) { return nil, func ParseRoutingMessage(b []byte) ([]RoutingMessage, error) { return nil, ENOSYS } func ParseRoutingSockaddr(msg RoutingMessage) ([]Sockaddr, error) { return nil, ENOSYS } func SysctlUint32(name string) (value uint32, err error) { return 0, ENOSYS } + +type Iovec struct{} // dummy