]> Cypherpunks repositories - gostls13.git/commitdiff
net: check read and write deadlines before doing syscalls
authorBrad Fitzpatrick <bradfitz@golang.org>
Sat, 24 Nov 2012 06:15:26 +0000 (22:15 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 24 Nov 2012 06:15:26 +0000 (22:15 -0800)
Otherwise a fast sender or receiver can make sockets always
readable or writable, preventing deadline checks from ever
occuring.

Update #4191 (fixes it with other CL, coming separately)
Fixes #4403

R=golang-dev, alex.brainman, dave, mikioh.mikioh
CC=golang-dev
https://golang.org/cl/6851096

src/pkg/net/fd_unix.go
src/pkg/net/timeout_test.go

index d87c51ec6631796df1a5ec5dd37b4df39bdc8445..16da53f0f5552af032fe0bf354810e4f1afc6631 100644 (file)
@@ -423,6 +423,12 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
        }
        defer fd.decref()
        for {
+               if fd.rdeadline > 0 {
+                       if time.Now().UnixNano() >= fd.rdeadline {
+                               err = errTimeout
+                               break
+                       }
+               }
                n, err = syscall.Read(int(fd.sysfd), p)
                if err == syscall.EAGAIN {
                        err = errTimeout
@@ -453,6 +459,12 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
        }
        defer fd.decref()
        for {
+               if fd.rdeadline > 0 {
+                       if time.Now().UnixNano() >= fd.rdeadline {
+                               err = errTimeout
+                               break
+                       }
+               }
                n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
                if err == syscall.EAGAIN {
                        err = errTimeout
@@ -481,6 +493,12 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
        }
        defer fd.decref()
        for {
+               if fd.rdeadline > 0 {
+                       if time.Now().UnixNano() >= fd.rdeadline {
+                               err = errTimeout
+                               break
+                       }
+               }
                n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
                if err == syscall.EAGAIN {
                        err = errTimeout
@@ -512,6 +530,12 @@ func (fd *netFD) Write(p []byte) (int, error) {
        var err error
        nn := 0
        for {
+               if fd.wdeadline > 0 {
+                       if time.Now().UnixNano() >= fd.wdeadline {
+                               err = errTimeout
+                               break
+                       }
+               }
                var n int
                n, err = syscall.Write(int(fd.sysfd), p[nn:])
                if n > 0 {
@@ -551,6 +575,12 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
        }
        defer fd.decref()
        for {
+               if fd.wdeadline > 0 {
+                       if time.Now().UnixNano() >= fd.wdeadline {
+                               err = errTimeout
+                               break
+                       }
+               }
                err = syscall.Sendto(fd.sysfd, p, 0, sa)
                if err == syscall.EAGAIN {
                        err = errTimeout
@@ -578,6 +608,12 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
        }
        defer fd.decref()
        for {
+               if fd.wdeadline > 0 {
+                       if time.Now().UnixNano() >= fd.wdeadline {
+                               err = errTimeout
+                               break
+                       }
+               }
                err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
                if err == syscall.EAGAIN {
                        err = errTimeout
index 68d8ced011a2b66dd131352072a1ad288a6f606e..b5b2fa28962e19f94e2a6909c3f0547b1c42798d 100644 (file)
@@ -6,11 +6,24 @@ package net
 
 import (
        "fmt"
+       "io"
+       "io/ioutil"
        "runtime"
        "testing"
        "time"
 )
 
+func isTimeout(err error) bool {
+       e, ok := err.(Error)
+       return ok && e.Timeout()
+}
+
+type copyRes struct {
+       n   int64
+       err error
+       d   time.Duration
+}
+
 func testTimeout(t *testing.T, net, addr string, readFrom bool) {
        c, err := Dial(net, addr)
        if err != nil {
@@ -230,3 +243,191 @@ func TestReadWriteDeadline(t *testing.T) {
        <-quit
        <-lnquit
 }
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+       for i := range p {
+               p[i] = byte(b)
+       }
+       return len(p), nil
+}
+
+func TestVariousDeadlines1Proc(t *testing.T) {
+       testVariousDeadlines(t, 1)
+}
+
+func TestVariousDeadlines4Proc(t *testing.T) {
+       testVariousDeadlines(t, 4)
+}
+
+func testVariousDeadlines(t *testing.T, maxProcs int) {
+       defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+       ln := newLocalListener(t)
+       defer ln.Close()
+       donec := make(chan struct{})
+       defer close(donec)
+
+       testsDone := func() bool {
+               select {
+               case <-donec:
+                       return true
+               }
+               return false
+       }
+
+       // The server, with no timeouts of its own, sending bytes to clients
+       // as fast as it can.
+       servec := make(chan copyRes)
+       go func() {
+               for {
+                       c, err := ln.Accept()
+                       if err != nil {
+                               if !testsDone() {
+                                       t.Fatalf("Accept: %v", err)
+                               }
+                               return
+                       }
+                       go func() {
+                               t0 := time.Now()
+                               n, err := io.Copy(c, neverEnding('a'))
+                               d := time.Since(t0)
+                               c.Close()
+                               servec <- copyRes{n, err, d}
+                       }()
+               }
+       }()
+
+       for _, timeout := range []time.Duration{
+               1 * time.Nanosecond,
+               2 * time.Nanosecond,
+               5 * time.Nanosecond,
+               50 * time.Nanosecond,
+               100 * time.Nanosecond,
+               200 * time.Nanosecond,
+               500 * time.Nanosecond,
+               750 * time.Nanosecond,
+               1 * time.Microsecond,
+               5 * time.Microsecond,
+               25 * time.Microsecond,
+               250 * time.Microsecond,
+               500 * time.Microsecond,
+               1 * time.Millisecond,
+               5 * time.Millisecond,
+               100 * time.Millisecond,
+               250 * time.Millisecond,
+               500 * time.Millisecond,
+               1 * time.Second,
+       } {
+               numRuns := 3
+               if testing.Short() {
+                       numRuns = 1
+                       if timeout > 500*time.Microsecond {
+                               continue
+                       }
+               }
+               for run := 0; run < numRuns; run++ {
+                       name := fmt.Sprintf("%v run %d/%d", timeout, run+1, numRuns)
+                       t.Log(name)
+
+                       c, err := Dial("tcp", ln.Addr().String())
+                       if err != nil {
+                               t.Fatalf("Dial: %v", err)
+                       }
+                       clientc := make(chan copyRes)
+                       go func() {
+                               t0 := time.Now()
+                               c.SetDeadline(t0.Add(timeout))
+                               n, err := io.Copy(ioutil.Discard, c)
+                               d := time.Since(t0)
+                               c.Close()
+                               clientc <- copyRes{n, err, d}
+                       }()
+
+                       const tooLong = 2000 * time.Millisecond
+                       select {
+                       case res := <-clientc:
+                               if isTimeout(res.err) {
+                                       t.Logf("for %v, good client timeout after %v, reading %d bytes", name, res.d, res.n)
+                               } else {
+                                       t.Fatalf("for %v: client Copy = %d, %v (want timeout)", name, res.n, res.err)
+                               }
+                       case <-time.After(tooLong):
+                               t.Fatalf("for %v: timeout (%v) waiting for client to timeout (%v) reading", name, tooLong, timeout)
+                       }
+
+                       select {
+                       case res := <-servec:
+                               t.Logf("for %v: server in %v wrote %d, %v", name, res.d, res.n, res.err)
+                       case <-time.After(tooLong):
+                               t.Fatalf("for %v, timeout waiting for server to finish writing", name)
+                       }
+               }
+       }
+}
+
+// TestReadDeadlineDataAvailable tests that read deadlines work, even
+// if there's data ready to be read.
+func TestReadDeadlineDataAvailable(t *testing.T) {
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       servec := make(chan copyRes)
+       const msg = "data client shouldn't read, even though it it'll be waiting"
+       go func() {
+               c, err := ln.Accept()
+               if err != nil {
+                       t.Fatalf("Accept: %v", err)
+               }
+               defer c.Close()
+               n, err := c.Write([]byte(msg))
+               servec <- copyRes{n: int64(n), err: err}
+       }()
+
+       c, err := Dial("tcp", ln.Addr().String())
+       if err != nil {
+               t.Fatalf("Dial: %v", err)
+       }
+       defer c.Close()
+       if res := <-servec; res.err != nil || res.n != int64(len(msg)) {
+               t.Fatalf("unexpected server Write: n=%d, err=%d; want n=%d, err=nil", res.n, res.err, len(msg))
+       }
+       c.SetReadDeadline(time.Now().Add(-5 * time.Second)) // in the psat.
+       buf := make([]byte, len(msg)/2)
+       n, err := c.Read(buf)
+       if n > 0 || !isTimeout(err) {
+               t.Fatalf("client read = %d (%q) err=%v; want 0, timeout", n, buf[:n], err)
+       }
+}
+
+// TestWriteDeadlineBufferAvailable tests that write deadlines work, even
+// if there's buffer space available to write.
+func TestWriteDeadlineBufferAvailable(t *testing.T) {
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       servec := make(chan copyRes)
+       go func() {
+               c, err := ln.Accept()
+               if err != nil {
+                       t.Fatalf("Accept: %v", err)
+               }
+               defer c.Close()
+               c.SetWriteDeadline(time.Now().Add(-5 * time.Second)) // in the past
+               n, err := c.Write([]byte{'x'})
+               servec <- copyRes{n: int64(n), err: err}
+       }()
+
+       c, err := Dial("tcp", ln.Addr().String())
+       if err != nil {
+               t.Fatalf("Dial: %v", err)
+       }
+       defer c.Close()
+       res := <-servec
+       if res.n != 0 {
+               t.Errorf("Write = %d; want 0", res.n)
+       }
+       if !isTimeout(res.err) {
+               t.Errorf("Write error = %v; want timeout", res.err)
+       }
+}