]> Cypherpunks repositories - gostls13.git/commitdiff
os/exec: start checking for context cancelation in Start
authorIan Lance Taylor <iant@golang.org>
Thu, 30 Jun 2016 02:20:51 +0000 (19:20 -0700)
committerIan Lance Taylor <iant@golang.org>
Thu, 30 Jun 2016 16:35:56 +0000 (16:35 +0000)
Previously we started checking for context cancelation in Wait, but
that meant that when using StdoutPipe context cancelation never took
effect.

Fixes #16222.

Change-Id: I89cd26d3499a6080bf1a07718ce38d825561899e
Reviewed-on: https://go-review.googlesource.com/24650
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/os/exec/exec.go
src/os/exec/exec_test.go

index 10300ce234a75120340a183d89c350f954ed5cc4..d2c1b17e509997078bd9e3b20af6a56aabecc4b7 100644 (file)
@@ -111,6 +111,7 @@ type Cmd struct {
        closeAfterWait  []io.Closer
        goroutine       []func() error
        errch           chan error // one send per goroutine
+       waitDone        chan struct{}
 }
 
 // Command returns the Cmd struct to execute the named program with
@@ -326,6 +327,15 @@ func (c *Cmd) Start() error {
        if c.Process != nil {
                return errors.New("exec: already started")
        }
+       if c.ctx != nil {
+               select {
+               case <-c.ctx.Done():
+                       c.closeDescriptors(c.closeAfterStart)
+                       c.closeDescriptors(c.closeAfterWait)
+                       return c.ctx.Err()
+               default:
+               }
+       }
 
        type F func(*Cmd) (*os.File, error)
        for _, setupFd := range []F{(*Cmd).stdin, (*Cmd).stdout, (*Cmd).stderr} {
@@ -361,6 +371,17 @@ func (c *Cmd) Start() error {
                }(fn)
        }
 
+       if c.ctx != nil {
+               c.waitDone = make(chan struct{})
+               go func() {
+                       select {
+                       case <-c.ctx.Done():
+                               c.Process.Kill()
+                       case <-c.waitDone:
+                       }
+               }()
+       }
+
        return nil
 }
 
@@ -410,20 +431,9 @@ func (c *Cmd) Wait() error {
        }
        c.finished = true
 
-       var waitDone chan struct{}
-       if c.ctx != nil {
-               waitDone = make(chan struct{})
-               go func() {
-                       select {
-                       case <-c.ctx.Done():
-                               c.Process.Kill()
-                       case <-waitDone:
-                       }
-               }()
-       }
        state, err := c.Process.Wait()
-       if waitDone != nil {
-               close(waitDone)
+       if c.waitDone != nil {
+               close(c.waitDone)
        }
        c.ProcessState = state
 
index 41f9dfe1c6361453a37d56a0808553722042f7e0..4cc984772155b168535d0e0422decc0558ccfd07 100644 (file)
@@ -878,3 +878,73 @@ func TestContext(t *testing.T) {
                t.Fatal("timeout waiting for child process death")
        }
 }
+
+func TestContextCancel(t *testing.T) {
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       c := helperCommandContext(t, ctx, "cat")
+
+       r, w, err := os.Pipe()
+       if err != nil {
+               t.Fatal(err)
+       }
+       c.Stdin = r
+
+       stdout, err := c.StdoutPipe()
+       if err != nil {
+               t.Fatal(err)
+       }
+       readDone := make(chan struct{})
+       go func() {
+               defer close(readDone)
+               var a [1024]byte
+               for {
+                       n, err := stdout.Read(a[:])
+                       if err != nil {
+                               if err != io.EOF {
+                                       t.Errorf("unexpected read error: %v", err)
+                               }
+                               return
+                       }
+                       t.Logf("%s", a[:n])
+               }
+       }()
+
+       if err := c.Start(); err != nil {
+               t.Fatal(err)
+       }
+
+       if err := r.Close(); err != nil {
+               t.Fatal(err)
+       }
+
+       if _, err := io.WriteString(w, "echo"); err != nil {
+               t.Fatal(err)
+       }
+
+       cancel()
+
+       // Calling cancel should have killed the process, so writes
+       // should now fail.  Give the process a little while to die.
+       start := time.Now()
+       for {
+               if _, err := io.WriteString(w, "echo"); err != nil {
+                       break
+               }
+               if time.Since(start) > time.Second {
+                       t.Fatal("cancelling context did not stop program")
+               }
+               time.Sleep(time.Millisecond)
+       }
+
+       if err := w.Close(); err != nil {
+               t.Error("error closing write end of pipe: %v", err)
+       }
+       <-readDone
+
+       if err := c.Wait(); err == nil {
+               t.Error("program unexpectedly exited successfully")
+       } else {
+               t.Logf("exit status: %v", err)
+       }
+}