]> Cypherpunks repositories - gostls13.git/commitdiff
io: add error check on pipe close functions to avoid error overwriting
authorJordi Martin <jordimartin@gmail.com>
Wed, 7 Aug 2019 11:14:38 +0000 (11:14 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 28 Aug 2019 18:35:24 +0000 (18:35 +0000)
The current implementation allows multiple calls `Close` and `CloseWithError` in every side of the pipe, as a result, the original error can be overwritten.

This CL fixes this behavior adding an error existence check on `atomicError` type
and keeping the first error still available.

Fixes #24283

Change-Id: Iefe8f758aeb775309424365f8177511062514150
GitHub-Last-Rev: b559540d7af3a0dad423816b695525ac2d6bd864
GitHub-Pull-Request: golang/go#33239
Reviewed-on: https://go-review.googlesource.com/c/go/+/187197
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Run-TryBot: Joe Tsai <thebrokentoaster@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/io/pipe.go
src/io/pipe_test.go

index 4efaf2f8e481381360eff85756be70baaf3a545c..b5343bb6b73b50d4a75d86d7c1ece87abd38d4ed 100644 (file)
@@ -10,19 +10,26 @@ package io
 import (
        "errors"
        "sync"
-       "sync/atomic"
 )
 
-// atomicError is a type-safe atomic value for errors.
-// We use a struct{ error } to ensure consistent use of a concrete type.
-type atomicError struct{ v atomic.Value }
+// onceError is an object that will only store an error once.
+type onceError struct {
+       sync.Mutex // guards following
+       err        error
+}
 
-func (a *atomicError) Store(err error) {
-       a.v.Store(struct{ error }{err})
+func (a *onceError) Store(err error) {
+       a.Lock()
+       defer a.Unlock()
+       if a.err != nil {
+               return
+       }
+       a.err = err
 }
-func (a *atomicError) Load() error {
-       err, _ := a.v.Load().(struct{ error })
-       return err.error
+func (a *onceError) Load() error {
+       a.Lock()
+       defer a.Unlock()
+       return a.err
 }
 
 // ErrClosedPipe is the error used for read or write operations on a closed pipe.
@@ -36,8 +43,8 @@ type pipe struct {
 
        once sync.Once // Protects closing done
        done chan struct{}
-       rerr atomicError
-       werr atomicError
+       rerr onceError
+       werr onceError
 }
 
 func (p *pipe) Read(b []byte) (n int, err error) {
@@ -135,6 +142,9 @@ func (r *PipeReader) Close() error {
 
 // CloseWithError closes the reader; subsequent writes
 // to the write half of the pipe will return the error err.
+//
+// CloseWithError never overwrites the previous error if it exists
+// and always returns nil.
 func (r *PipeReader) CloseWithError(err error) error {
        return r.p.CloseRead(err)
 }
@@ -163,7 +173,8 @@ func (w *PipeWriter) Close() error {
 // read half of the pipe will return no bytes and the error err,
 // or EOF if err is nil.
 //
-// CloseWithError always returns nil.
+// CloseWithError never overwrites the previous error if it exists
+// and always returns nil.
 func (w *PipeWriter) CloseWithError(err error) error {
        return w.p.CloseWrite(err)
 }
index f18b1c45f8b5c553df74176402b9fe0604baeb4a..8973360740181ba0579e966d840c2f4d10c18f9a 100644 (file)
@@ -326,8 +326,8 @@ func TestPipeCloseError(t *testing.T) {
                t.Errorf("Write error: got %T, want testError1", err)
        }
        r.CloseWithError(testError2{})
-       if _, err := w.Write(nil); err != (testError2{}) {
-               t.Errorf("Write error: got %T, want testError2", err)
+       if _, err := w.Write(nil); err != (testError1{}) {
+               t.Errorf("Write error: got %T, want testError1", err)
        }
 
        r, w = Pipe()
@@ -336,8 +336,8 @@ func TestPipeCloseError(t *testing.T) {
                t.Errorf("Read error: got %T, want testError1", err)
        }
        w.CloseWithError(testError2{})
-       if _, err := r.Read(nil); err != (testError2{}) {
-               t.Errorf("Read error: got %T, want testError2", err)
+       if _, err := r.Read(nil); err != (testError1{}) {
+               t.Errorf("Read error: got %T, want testError1", err)
        }
 }