"fmt"
"io"
"io/ioutil"
+ "net"
. "net/http"
"net/http/httptest"
"net/url"
"runtime"
"strconv"
"strings"
+ "sync"
"testing"
"time"
)
w.Write([]byte(r.RemoteAddr))
})
+type testCloseConn struct {
+ net.Conn
+ set *testConnSet
+}
+
+func (conn *testCloseConn) Close() error {
+ conn.set.remove(conn)
+ return conn.Conn.Close()
+}
+
+type testConnSet struct {
+ set map[net.Conn]bool
+ mutex sync.Mutex
+}
+
+func (tcs *testConnSet) insert(c net.Conn) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+ tcs.set[c] = true
+}
+
+func (tcs *testConnSet) remove(c net.Conn) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+ // just change to false, so we have a full set of opened connections
+ tcs.set[c] = false
+}
+
+// some tests use this to manage raw tcp connections for later inspection
+func makeTestDial() (*testConnSet, func(n, addr string) (net.Conn, error)) {
+ connSet := &testConnSet{
+ set: make(map[net.Conn]bool),
+ }
+ dial := func(n, addr string) (net.Conn, error) {
+ c, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testCloseConn{c, connSet}
+ connSet.insert(tc)
+ return tc, nil
+ }
+ return connSet, dial
+}
+
+func (tcs *testConnSet) countClosed() (closed, total int) {
+ tcs.mutex.Lock()
+ defer tcs.mutex.Unlock()
+
+ total = len(tcs.set)
+ for _, open := range tcs.set {
+ if !open {
+ closed += 1
+ }
+ }
+ return
+}
+
// Two subsequent requests and verify their response is the same.
// The response from the server is our own IP:port
func TestTransportKeepAlives(t *testing.T) {
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
+ connSet, testDial := makeTestDial()
+
for _, connectionClose := range []bool{false, true} {
- tr := &Transport{}
+ tr := &Transport{
+ Dial: testDial,
+ }
c := &Client{Transport: tr}
fetch := func(n int) string {
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
connectionClose, bodiesDiffer, body1, body2)
}
+
+ tr.CloseIdleConnections()
+ }
+
+ closed, total := connSet.countClosed()
+ if closed < total {
+ t.Errorf("%d out of %d tcp connections were not closed", total-closed, total)
}
}
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
+ connSet, testDial := makeTestDial()
+
for _, connectionClose := range []bool{false, true} {
- tr := &Transport{}
+ tr := &Transport{
+ Dial: testDial,
+ }
c := &Client{Transport: tr}
fetch := func(n int) string {
t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
connectionClose, bodiesDiffer, body1, body2)
}
+
+ tr.CloseIdleConnections()
+ }
+
+ closed, total := connSet.countClosed()
+ if closed < total {
+ t.Errorf("%d out of %d tcp connections were not closed", total-closed, total)
}
}