]> Cypherpunks repositories - gostls13.git/commitdiff
net: break up >1GB reads and writes on stream connections
authorRuss Cox <rsc@golang.org>
Thu, 20 Oct 2016 21:14:47 +0000 (17:14 -0400)
committerRuss Cox <rsc@golang.org>
Sat, 29 Oct 2016 18:23:59 +0000 (18:23 +0000)
Also fix behavior of Buffers.WriteTo when writev returns an error.

Fixes #16266.

Change-Id: Idc9503408ce2cb460663768fab86035cbab11aef
Reviewed-on: https://go-review.googlesource.com/31584
Run-TryBot: Russ Cox <rsc@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/net/fd_plan9.go
src/net/fd_unix.go
src/net/fd_windows.go
src/net/tcpsock_test.go
src/net/writev_test.go
src/net/writev_unix.go

index d32b622966c38f50ee0e3bde3282764065dac0ba..ab5db38dbec0c2aba7be4f7712b7178595827ce0 100644 (file)
@@ -22,6 +22,7 @@ type netFD struct {
        dir               string
        listen, ctl, data *os.File
        laddr, raddr      Addr
+       isStream          bool
 }
 
 var (
index 1296bc56b2021eb57dd37ca93a8df45cb264b80c..3c95fc01d401dbcad0162f17e571f8ea656ffe2c 100644 (file)
@@ -24,6 +24,7 @@ type netFD struct {
        sysfd       int
        family      int
        sotype      int
+       isStream    bool
        isConnected bool
        net         string
        laddr       Addr
@@ -40,7 +41,7 @@ func sysInit() {
 }
 
 func newFD(sysfd, family, sotype int, net string) (*netFD, error) {
-       return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil
+       return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net, isStream: sotype == syscall.SOCK_STREAM}, nil
 }
 
 func (fd *netFD) init() error {
@@ -238,6 +239,9 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
        if err := fd.pd.prepareRead(); err != nil {
                return 0, err
        }
+       if fd.isStream && len(p) > 1<<30 {
+               p = p[:1<<30]
+       }
        for {
                n, err = syscall.Read(fd.sysfd, p)
                if err != nil {
@@ -321,7 +325,11 @@ func (fd *netFD) Write(p []byte) (nn int, err error) {
        }
        for {
                var n int
-               n, err = syscall.Write(fd.sysfd, p[nn:])
+               max := len(p)
+               if fd.isStream && max-nn > 1<<30 {
+                       max = nn + 1<<30
+               }
+               n, err = syscall.Write(fd.sysfd, p[nn:max])
                if n > 0 {
                        nn += n
                }
index d1c368e8836fa3a7231734fa509c81017a9bbc5b..828da4a2e62fa0b8c536658a541763c8a14411ec 100644 (file)
@@ -239,6 +239,7 @@ type netFD struct {
        sysfd         syscall.Handle
        family        int
        sotype        int
+       isStream      bool
        isConnected   bool
        skipSyncNotif bool
        net           string
@@ -257,7 +258,7 @@ func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error)
                return nil, initErr
        }
        onceStartServer.Do(startServer)
-       return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil
+       return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net, isStream: sotype == syscall.SOCK_STREAM}, nil
 }
 
 func (fd *netFD) init() error {
index d80a3736bfb437f0cd62cd893386f3be2b0a8d68..0d283dfa4f78b3249f8a18ab86bcc851e44dc323 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "fmt"
        "internal/testenv"
        "io"
        "reflect"
@@ -654,3 +655,58 @@ func TestTCPSelfConnect(t *testing.T) {
                }
        }
 }
+
+// Test that >32-bit reads work on 64-bit systems.
+// On 32-bit systems this tests that maxint reads work.
+func TestTCPBig(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping test in short mode")
+       }
+
+       for _, writev := range []bool{false, true} {
+               t.Run(fmt.Sprintf("writev=%v", writev), func(t *testing.T) {
+                       ln, err := newLocalListener("tcp")
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       defer ln.Close()
+
+                       x := int(1 << 30)
+                       x = x*5 + 1<<20 // just over 5 GB on 64-bit, just over 1GB on 32-bit
+                       done := make(chan int)
+                       go func() {
+                               defer close(done)
+                               c, err := ln.Accept()
+                               if err != nil {
+                                       t.Error(err)
+                                       return
+                               }
+                               buf := make([]byte, x)
+                               var n int
+                               if writev {
+                                       var n64 int64
+                                       n64, err = (&Buffers{buf}).WriteTo(c)
+                                       n = int(n64)
+                               } else {
+                                       n, err = c.Write(buf)
+                               }
+                               if n != len(buf) || err != nil {
+                                       t.Errorf("Write(buf) = %d, %v, want %d, nil", n, err, x)
+                               }
+                               c.Close()
+                       }()
+
+                       c, err := Dial("tcp", ln.Addr().String())
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       buf := make([]byte, x)
+                       n, err := io.ReadFull(c, buf)
+                       if n != len(buf) || err != nil {
+                               t.Errorf("Read(buf) = %d, %v, want %d, nil", n, err, x)
+                       }
+                       c.Close()
+                       <-done
+               })
+       }
+}
index cc53adcdd1ffe6d0125725054d3633b7aa286732..385cc1250392bee1cc63832e5e8b7a0a7523ecda 100644 (file)
@@ -169,3 +169,44 @@ func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
                return nil
        })
 }
+
+func TestWritevError(t *testing.T) {
+       ln, err := newLocalListener("tcp")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer ln.Close()
+
+       ch := make(chan Conn, 1)
+       go func() {
+               defer close(ch)
+               c, err := ln.Accept()
+               if err != nil {
+                       t.Error(err)
+                       return
+               }
+               ch <- c
+       }()
+       c1, err := Dial("tcp", ln.Addr().String())
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer c1.Close()
+       c2 := <-ch
+       if c2 == nil {
+               t.Fatal("no server side connection")
+       }
+       c2.Close()
+
+       // 1 GB of data should be enough to notice the connection is gone.
+       // Just a few bytes is not enough.
+       // Arrange to reuse the same 1 MB buffer so that we don't allocate much.
+       buf := make([]byte, 1<<20)
+       buffers := make(Buffers, 1<<10)
+       for i := range buffers {
+               buffers[i] = buf
+       }
+       if _, err := buffers.WriteTo(c1); err == nil {
+               t.Fatalf("Buffers.WriteTo(closed conn) succeeded, want error", err)
+       }
+}
index ac4f7cf61adbe314f40cb9fbce0c559c72b0d4e4..174e6bc51e3a65b8f3546ad7c4c88f966cf006ee 100644 (file)
@@ -49,6 +49,10 @@ func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) {
                                continue
                        }
                        iovecs = append(iovecs, syscall.Iovec{Base: &chunk[0]})
+                       if fd.isStream && len(chunk) > 1<<30 {
+                               iovecs[len(iovecs)-1].SetLen(1 << 30)
+                               break // continue chunk on next writev
+                       }
                        iovecs[len(iovecs)-1].SetLen(len(chunk))
                        if len(iovecs) == maxVec {
                                break
@@ -63,7 +67,7 @@ func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) {
                        uintptr(fd.sysfd),
                        uintptr(unsafe.Pointer(&iovecs[0])),
                        uintptr(len(iovecs)))
-               if wrote < 0 {
+               if wrote == ^uintptr(0) {
                        wrote = 0
                }
                testHookDidWritev(int(wrote))