package net
import (
- "errors"
"io"
+ "sync"
"time"
)
+// pipeDeadline is an abstraction for handling timeouts.
+type pipeDeadline struct {
+ mu sync.Mutex // Guards timer and cancel
+ timer *time.Timer
+ cancel chan struct{} // Must be non-nil
+}
+
+func makePipeDeadline() pipeDeadline {
+ return pipeDeadline{cancel: make(chan struct{})}
+}
+
+// set sets the point in time when the deadline will time out.
+// A timeout event is signaled by closing the channel returned by waiter.
+// Once a timeout has occurred, the deadline can be refreshed by specifying a
+// t value in the future.
+//
+// A zero value for t prevents timeout.
+func (d *pipeDeadline) set(t time.Time) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.timer != nil && !d.timer.Stop() {
+ <-d.cancel // Wait for the timer callback to finish and close cancel
+ }
+ d.timer = nil
+
+ // Time is zero, then there is no deadline.
+ closed := isClosedChan(d.cancel)
+ if t.IsZero() {
+ if closed {
+ d.cancel = make(chan struct{})
+ }
+ return
+ }
+
+ // Time in the future, setup a timer to cancel in the future.
+ if dur := time.Until(t); dur > 0 {
+ if closed {
+ d.cancel = make(chan struct{})
+ }
+ d.timer = time.AfterFunc(dur, func() {
+ close(d.cancel)
+ })
+ return
+ }
+
+ // Time in the past, so close immediately.
+ if !closed {
+ close(d.cancel)
+ }
+}
+
+// wait returns a channel that is closed when the deadline is exceeded.
+func (d *pipeDeadline) wait() chan struct{} {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.cancel
+}
+
+func isClosedChan(c <-chan struct{}) bool {
+ select {
+ case <-c:
+ return true
+ default:
+ return false
+ }
+}
+
+type pipeError struct {
+ errStr string
+ timeout bool
+}
+
+func (pe pipeError) Error() string { return pe.errStr }
+func (pe pipeError) Timeout() bool { return pe.timeout }
+func (pe pipeError) Temporary() bool { return pe.timeout }
+
+var (
+ errDeadline = pipeError{"deadline exceeded", true}
+ errClosed = pipeError{"closed connection", false}
+)
+
+type pipeAddr struct{}
+
+func (pipeAddr) Network() string { return "pipe" }
+func (pipeAddr) String() string { return "pipe" }
+
+type pipe struct {
+ wrMu sync.Mutex // Serialize Write operations
+
+ // Used by local Read to interact with remote Write.
+ // Successful receive on rdRx is always followed by send on rdTx.
+ rdRx <-chan []byte
+ rdTx chan<- int
+
+ // Used by local Write to interact with remote Read.
+ // Successful send on wrTx is always followed by receive on wrRx.
+ wrTx chan<- []byte
+ wrRx <-chan int
+
+ once sync.Once // Protects closing localDone
+ localDone chan struct{}
+ remoteDone <-chan struct{}
+
+ readDeadline pipeDeadline
+ writeDeadline pipeDeadline
+}
+
// Pipe creates a synchronous, in-memory, full duplex
// network connection; both ends implement the Conn interface.
// Reads on one end are matched with writes on the other,
// copying data directly between the two; there is no internal
// buffering.
func Pipe() (Conn, Conn) {
- r1, w1 := io.Pipe()
- r2, w2 := io.Pipe()
+ cb1 := make(chan []byte)
+ cb2 := make(chan []byte)
+ cn1 := make(chan int)
+ cn2 := make(chan int)
+ done1 := make(chan struct{})
+ done2 := make(chan struct{})
- return &pipe{r1, w2}, &pipe{r2, w1}
-}
-
-type pipe struct {
- *io.PipeReader
- *io.PipeWriter
+ p1 := &pipe{
+ rdRx: cb1, rdTx: cn1,
+ wrTx: cb2, wrRx: cn2,
+ localDone: done1, remoteDone: done2,
+ readDeadline: makePipeDeadline(),
+ writeDeadline: makePipeDeadline(),
+ }
+ p2 := &pipe{
+ rdRx: cb2, rdTx: cn2,
+ wrTx: cb1, wrRx: cn1,
+ localDone: done2, remoteDone: done1,
+ readDeadline: makePipeDeadline(),
+ writeDeadline: makePipeDeadline(),
+ }
+ return p1, p2
}
-type pipeAddr int
+func (*pipe) LocalAddr() Addr { return pipeAddr{} }
+func (*pipe) RemoteAddr() Addr { return pipeAddr{} }
-func (pipeAddr) Network() string {
- return "pipe"
+func (p *pipe) Read(b []byte) (int, error) {
+ n, err := p.read(b)
+ if err != nil && err != io.EOF {
+ err = &OpError{Op: "read", Net: "pipe", Err: err}
+ }
+ return n, err
}
-func (pipeAddr) String() string {
- return "pipe"
-}
+func (p *pipe) read(b []byte) (n int, err error) {
+ switch {
+ case isClosedChan(p.localDone):
+ return 0, errClosed
+ case isClosedChan(p.remoteDone):
+ return 0, io.EOF
+ case isClosedChan(p.readDeadline.wait()):
+ return 0, errDeadline
+ }
-func (p *pipe) Close() error {
- err := p.PipeReader.Close()
- err1 := p.PipeWriter.Close()
- if err == nil {
- err = err1
+ select {
+ case bw := <-p.rdRx:
+ nr := copy(b, bw)
+ p.rdTx <- nr
+ return nr, nil
+ case <-p.localDone:
+ return 0, errClosed
+ case <-p.remoteDone:
+ return 0, io.EOF
+ case <-p.readDeadline.wait():
+ return 0, errDeadline
}
- return err
}
-func (p *pipe) LocalAddr() Addr {
- return pipeAddr(0)
+func (p *pipe) Write(b []byte) (int, error) {
+ n, err := p.write(b)
+ if err != nil {
+ err = &OpError{Op: "write", Net: "pipe", Err: err}
+ }
+ return n, err
}
-func (p *pipe) RemoteAddr() Addr {
- return pipeAddr(0)
+func (p *pipe) write(b []byte) (n int, err error) {
+ switch {
+ case isClosedChan(p.localDone):
+ return 0, errClosed
+ case isClosedChan(p.remoteDone):
+ return 0, errClosed
+ case isClosedChan(p.writeDeadline.wait()):
+ return 0, errDeadline
+ }
+
+ p.wrMu.Lock() // Ensure entirety of b is written together
+ defer p.wrMu.Unlock()
+ for once := true; once || len(b) > 0; once = false {
+ select {
+ case p.wrTx <- b:
+ nw := <-p.wrRx
+ b = b[nw:]
+ n += nw
+ case <-p.localDone:
+ return n, errClosed
+ case <-p.remoteDone:
+ return n, errClosed
+ case <-p.writeDeadline.wait():
+ return n, errDeadline
+ }
+ }
+ return n, nil
}
func (p *pipe) SetDeadline(t time.Time) error {
- return &OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
+ if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
+ return &OpError{Op: "set", Net: "pipe", Err: errClosed}
+ }
+ p.readDeadline.set(t)
+ p.writeDeadline.set(t)
+ return nil
}
func (p *pipe) SetReadDeadline(t time.Time) error {
- return &OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
+ if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
+ return &OpError{Op: "set", Net: "pipe", Err: errClosed}
+ }
+ p.readDeadline.set(t)
+ return nil
}
func (p *pipe) SetWriteDeadline(t time.Time) error {
- return &OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
+ if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
+ return &OpError{Op: "set", Net: "pipe", Err: errClosed}
+ }
+ p.writeDeadline.set(t)
+ return nil
+}
+
+func (p *pipe) Close() error {
+ p.once.Do(func() { close(p.localDone) })
+ return nil
}
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package net
+package net_test
import (
- "bytes"
- "io"
+ "net"
"testing"
-)
-
-func checkPipeWrite(t *testing.T, w io.Writer, data []byte, c chan int) {
- n, err := w.Write(data)
- if err != nil {
- t.Error(err)
- }
- if n != len(data) {
- t.Errorf("short write: %d != %d", n, len(data))
- }
- c <- 0
-}
-func checkPipeRead(t *testing.T, r io.Reader, data []byte, wantErr error) {
- buf := make([]byte, len(data)+10)
- n, err := r.Read(buf)
- if err != wantErr {
- t.Error(err)
- return
- }
- if n != len(data) || !bytes.Equal(buf[0:n], data) {
- t.Errorf("bad read: got %q", buf[0:n])
- return
- }
-}
+ "golang_org/x/net/nettest"
+)
-// TestPipe tests a simple read/write/close sequence.
-// Assumes that the underlying io.Pipe implementation
-// is solid and we're just testing the net wrapping.
func TestPipe(t *testing.T) {
- c := make(chan int)
- cli, srv := Pipe()
- go checkPipeWrite(t, cli, []byte("hello, world"), c)
- checkPipeRead(t, srv, []byte("hello, world"), nil)
- <-c
- go checkPipeWrite(t, srv, []byte("line 2"), c)
- checkPipeRead(t, cli, []byte("line 2"), nil)
- <-c
- go checkPipeWrite(t, cli, []byte("a third line"), c)
- checkPipeRead(t, srv, []byte("a third line"), nil)
- <-c
- go srv.Close()
- checkPipeRead(t, cli, nil, io.EOF)
- cli.Close()
+ nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
+ c1, c2 = net.Pipe()
+ stop = func() {
+ c1.Close()
+ c2.Close()
+ }
+ return
+ })
}