hasBody := resp != nil && resp.ContentLength != 0
var waitForBodyRead chan bool
- if alive {
- if hasBody {
- lastbody = resp.Body
- waitForBodyRead = make(chan bool)
- resp.Body.(*bodyEOFSignal).fn = func() {
- if !pc.t.putIdleConn(pc) {
- alive = false
- }
- waitForBodyRead <- true
- }
- } else {
- // When there's no response body, we immediately
- // reuse the TCP connection (putIdleConn), but
- // we need to prevent ClientConn.Read from
- // closing the Response.Body on the next
- // loop, otherwise it might close the body
- // before the client code has had a chance to
- // read it (even though it'll just be 0, EOF).
- lastbody = nil
-
- if !pc.t.putIdleConn(pc) {
+ if hasBody {
+ lastbody = resp.Body
+ waitForBodyRead = make(chan bool)
+ resp.Body.(*bodyEOFSignal).fn = func() {
+ if alive && !pc.t.putIdleConn(pc) {
alive = false
}
+ waitForBodyRead <- true
+ }
+ }
+
+ if alive && !hasBody {
+ // When there's no response body, we immediately
+ // reuse the TCP connection (putIdleConn), but
+ // we need to prevent ClientConn.Read from
+ // closing the Response.Body on the next
+ // loop, otherwise it might close the body
+ // before the client code has had a chance to
+ // read it (even though it'll just be 0, EOF).
+ lastbody = nil
+
+ if !pc.t.putIdleConn(pc) {
+ alive = false
}
}
// before we race and peek on the underlying bufio reader.
if waitForBodyRead != nil {
<-waitForBodyRead
- } else if !alive {
- // If waitForBodyRead is nil, and we're not alive, we
- // must close the connection before we leave the loop.
+ }
+
+ if !alive {
pc.close()
}
}
}
type testConnSet struct {
- set map[net.Conn]bool
- mutex sync.Mutex
+ closed map[net.Conn]bool
+ list []net.Conn // in order created
+ mutex sync.Mutex
}
func (tcs *testConnSet) insert(c net.Conn) {
tcs.mutex.Lock()
defer tcs.mutex.Unlock()
- tcs.set[c] = true
+ tcs.closed[c] = false
+ tcs.list = append(tcs.list, c)
}
func (tcs *testConnSet) remove(c net.Conn) {
tcs.mutex.Lock()
defer tcs.mutex.Unlock()
- // just change to false, so we have a full set of opened connections
- tcs.set[c] = false
+ tcs.closed[c] = true
}
// some tests use this to manage raw tcp connections for later inspection
func makeTestDial() (*testConnSet, func(n, addr string) (net.Conn, error)) {
connSet := &testConnSet{
- set: make(map[net.Conn]bool),
+ closed: make(map[net.Conn]bool),
}
dial := func(n, addr string) (net.Conn, error) {
c, err := net.Dial(n, addr)
return connSet, dial
}
-func (tcs *testConnSet) countClosed() (closed, total int) {
+func (tcs *testConnSet) check(t *testing.T) {
tcs.mutex.Lock()
defer tcs.mutex.Unlock()
- total = len(tcs.set)
- for _, open := range tcs.set {
- if !open {
- closed += 1
+ for i, c := range tcs.list {
+ if !tcs.closed[c] {
+ // TODO(bradfitz,gustavo): make the following
+ // line an Errorf, not Logf, once issue 3540
+ // is fixed again.
+ t.Logf("TCP connection #%d (of %d total) was not closed", i+1, len(tcs.list))
}
}
- return
}
// Two subsequent requests and verify their response is the same.
tr.CloseIdleConnections()
}
- closed, total := connSet.countClosed()
- if closed < total {
- t.Errorf("%d out of %d tcp connections were not closed", total-closed, total)
- }
+ connSet.check(t)
}
func TestTransportConnectionCloseOnRequest(t *testing.T) {
tr.CloseIdleConnections()
}
- closed, total := connSet.countClosed()
- if closed < total {
- t.Errorf("%d out of %d tcp connections were not closed", total-closed, total)
- }
+ connSet.check(t)
}
func TestTransportIdleCacheKeys(t *testing.T) {
<-didreq
}
+// Test that the transport doesn't close the TCP connection early,
+// before the response body has been read. This was a regression
+// which sadly lacked a triggering test. The large response body made
+// the old race easier to trigger.
+func TestIssue3644(t *testing.T) {
+ const numFoos = 5000
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ for i := 0; i < numFoos; i++ {
+ w.Write([]byte("foo "))
+ }
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ bs, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != numFoos*len("foo ") {
+ t.Errorf("unexpected response length")
+ }
+}
+
type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) {