]> Cypherpunks repositories - gostls13.git/commitdiff
http: keep gzip reader inside eofsignaler
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 27 Apr 2011 21:23:25 +0000 (14:23 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 27 Apr 2011 21:23:25 +0000 (14:23 -0700)
Fixes #1725

R=rsc
CC=golang-dev
https://golang.org/cl/4442086

src/pkg/http/response_test.go
src/pkg/http/transport.go
src/pkg/http/transport_test.go

index 5e76bbb9e1f60ad6910efdbdd9f922ac3acee0b3..9e77c20c40b72be1afd3fc634418d6b3d8fa45e6 100644 (file)
@@ -7,11 +7,13 @@ package http
 import (
        "bufio"
        "bytes"
+       "compress/gzip"
+       "crypto/rand"
        "fmt"
+       "os"
        "io"
        "io/ioutil"
        "reflect"
-       "strings"
        "testing"
 )
 
@@ -277,64 +279,100 @@ func TestReadResponse(t *testing.T) {
        }
 }
 
-// TestReadResponseCloseInMiddle tests that for both chunked and unchunked responses,
-// if we close the Body while only partway through reading, the underlying reader
-// advanced to the end of the request.
+var readResponseCloseInMiddleTests = []struct {
+       chunked, compressed bool
+}{
+       {false, false},
+       {true, false},
+       {true, true},
+}
+
+// TestReadResponseCloseInMiddle tests that closing a body after
+// reading only part of its contents advances the read to the end of
+// the request, right up until the next request.
 func TestReadResponseCloseInMiddle(t *testing.T) {
-       for _, chunked := range []bool{false, true} {
+       for _, test := range readResponseCloseInMiddleTests {
+               fatalf := func(format string, args ...interface{}) {
+                       args = append([]interface{}{test.chunked, test.compressed}, args...)
+                       t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...)
+               }
+               checkErr := func(err os.Error, msg string) {
+                       if err == nil {
+                               return
+                       }
+                       fatalf(msg+": %v", err)
+               }
                var buf bytes.Buffer
                buf.WriteString("HTTP/1.1 200 OK\r\n")
-               if chunked {
-                       buf.WriteString("Transfer-Encoding: chunked\r\n\r\n")
+               if test.chunked {
+                       buf.WriteString("Transfer-Encoding: chunked\r\n")
                } else {
-                       buf.WriteString("Content-Length: 1000000\r\n\r\n")
+                       buf.WriteString("Content-Length: 1000000\r\n")
                }
-               chunk := strings.Repeat("x", 1000)
+               var wr io.Writer = &buf
+               if test.chunked {
+                       wr = &chunkedWriter{wr}
+               }
+               if test.compressed {
+                       buf.WriteString("Content-Encoding: gzip\r\n")
+                       var err os.Error
+                       wr, err = gzip.NewWriter(wr)
+                       checkErr(err, "gzip.NewWriter")
+               }
+               buf.WriteString("\r\n")
+
+               chunk := bytes.Repeat([]byte{'x'}, 1000)
                for i := 0; i < 1000; i++ {
-                       if chunked {
-                               buf.WriteString("03E8\r\n")
-                               buf.WriteString(chunk)
-                               buf.WriteString("\r\n")
-                       } else {
-                               buf.WriteString(chunk)
+                       if test.compressed {
+                               // Otherwise this compresses too well.
+                               _, err := io.ReadFull(rand.Reader, chunk)
+                               checkErr(err, "rand.Reader ReadFull")
                        }
+                       wr.Write(chunk)
                }
-               if chunked {
+               if test.compressed {
+                       err := wr.(*gzip.Compressor).Close()
+                       checkErr(err, "compressor close")
+               }
+               if test.chunked {
                        buf.WriteString("0\r\n\r\n")
                }
                buf.WriteString("Next Request Here")
+
                bufr := bufio.NewReader(&buf)
                resp, err := ReadResponse(bufr, "GET")
-               if err != nil {
-                       t.Fatalf("parse error for chunked=%v: %v", chunked, err)
-               }
-
+               checkErr(err, "ReadResponse")
                expectedLength := int64(-1)
-               if !chunked {
+               if !test.chunked {
                        expectedLength = 1000000
                }
                if resp.ContentLength != expectedLength {
-                       t.Fatalf("chunked=%v: expected response length %d, got %d", chunked, expectedLength, resp.ContentLength)
+                       fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength)
+               }
+               if resp.Body == nil {
+                       fatalf("nil body")
+               }
+               if test.compressed {
+                       gzReader, err := gzip.NewReader(resp.Body)
+                       checkErr(err, "gzip.NewReader")
+                       resp.Body = &readFirstCloseBoth{gzReader, resp.Body}
                }
+
                rbuf := make([]byte, 2500)
                n, err := io.ReadFull(resp.Body, rbuf)
-               if err != nil {
-                       t.Fatalf("ReadFull error for chunked=%v: %v", chunked, err)
-               }
+               checkErr(err, "2500 byte ReadFull")
                if n != 2500 {
-                       t.Fatalf("ReadFull only read %n bytes for chunked=%v", n, chunked)
+                       fatalf("ReadFull only read %d bytes", n)
                }
-               if !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) {
-                       t.Fatalf("ReadFull didn't read 2500 'x' for chunked=%v; got %q", chunked, string(rbuf))
+               if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) {
+                       fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf))
                }
                resp.Body.Close()
 
                rest, err := ioutil.ReadAll(bufr)
-               if err != nil {
-                       t.Fatalf("ReadAll error on remainder for chunked=%v: %v", chunked, err)
-               }
+               checkErr(err, "ReadAll on remainder")
                if e, g := "Next Request Here", string(rest); e != g {
-                       t.Fatalf("for chunked=%v remainder = %q, expected %q", chunked, g, e)
+                       fatalf("for chunked=%v remainder = %q, expected %q", g, e)
                }
        }
 }
index 98ac203b7277a7d5796d9df2be0f5a6432042b18..73a2c2191ea403adec3cdd57f2d9ba6b031a0288 100644 (file)
@@ -532,12 +532,13 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
                re.res.Header.Del("Content-Encoding")
                re.res.Header.Del("Content-Length")
                re.res.ContentLength = -1
-               gzReader, err := gzip.NewReader(re.res.Body)
+               esb := re.res.Body.(*bodyEOFSignal)
+               gzReader, err := gzip.NewReader(esb.body)
                if err != nil {
                        pc.close()
                        return nil, err
                }
-               re.res.Body = &readFirstCloseBoth{gzReader, re.res.Body}
+               esb.body = &readFirstCloseBoth{gzReader, esb.body}
        }
 
        return re.res, re.err
@@ -619,6 +620,7 @@ type readFirstCloseBoth struct {
 
 func (r *readFirstCloseBoth) Close() os.Error {
        if err := r.ReadCloser.Close(); err != nil {
+               r.Closer.Close()
                return err
        }
        if err := r.Closer.Close(); err != nil {
index de3a35153013f9b98dbc85480872a639eed7a35b..a32ac4c4f0a3b39cac26845747ba4076de613026 100644 (file)
@@ -9,11 +9,14 @@ package http_test
 import (
        "bytes"
        "compress/gzip"
+       "crypto/rand"
        "fmt"
        . "http"
        "http/httptest"
+       "io"
        "io/ioutil"
        "os"
+       "strconv"
        "testing"
        "time"
 )
@@ -367,32 +370,80 @@ func TestTransportNilURL(t *testing.T) {
 
 func TestTransportGzip(t *testing.T) {
        const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
-       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
-               if g, e := r.Header.Get("Accept-Encoding"), "gzip"; g != e {
+       const nRandBytes = 1024 * 1024
+       ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+               if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
                        t.Errorf("Accept-Encoding = %q, want %q", g, e)
                }
-               w.Header().Set("Content-Encoding", "gzip")
+               rw.Header().Set("Content-Encoding", "gzip")
+
+               var w io.Writer = rw
+               var buf bytes.Buffer
+               if req.FormValue("chunked") == "0" {
+                       w = &buf
+                       defer io.Copy(rw, &buf)
+                       defer func() {
+                               rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
+                       }()
+               }
                gz, _ := gzip.NewWriter(w)
-               defer gz.Close()
                gz.Write([]byte(testString))
-
+               if req.FormValue("body") == "large" {
+                       io.Copyn(gz, rand.Reader, nRandBytes)
+               }
+               gz.Close()
        }))
        defer ts.Close()
 
-       c := &Client{Transport: &Transport{}}
-       res, _, err := c.Get(ts.URL)
-       if err != nil {
-               t.Fatal(err)
-       }
-       body, err := ioutil.ReadAll(res.Body)
-       if err != nil {
-               t.Fatal(err)
-       }
-       if g, e := string(body), testString; g != e {
-               t.Fatalf("body = %q; want %q", g, e)
-       }
-       if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
-               t.Fatalf("Content-Encoding = %q; want %q", g, e)
+       for _, chunked := range []string{"1", "0"} {
+               c := &Client{Transport: &Transport{}}
+
+               // First fetch something large, but only read some of it.
+               res, _, err := c.Get(ts.URL + "?body=large&chunked=" + chunked)
+               if err != nil {
+                       t.Fatalf("large get: %v", err)
+               }
+               buf := make([]byte, len(testString))
+               n, err := io.ReadFull(res.Body, buf)
+               if err != nil {
+                       t.Fatalf("partial read of large response: size=%d, %v", n, err)
+               }
+               if e, g := testString, string(buf); e != g {
+                       t.Errorf("partial read got %q, expected %q", g, e)
+               }
+               res.Body.Close()
+               // Read on the body, even though it's closed
+               n, err = res.Body.Read(buf)
+               if n != 0 || err == nil {
+                       t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
+               }
+
+               // Then something small.
+               res, _, err = c.Get(ts.URL + "?chunked=" + chunked)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               body, err := ioutil.ReadAll(res.Body)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if g, e := string(body), testString; g != e {
+                       t.Fatalf("body = %q; want %q", g, e)
+               }
+               if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+                       t.Fatalf("Content-Encoding = %q; want %q", g, e)
+               }
+
+               // Read on the body after it's been fully read:
+               n, err = res.Body.Read(buf)
+               if n != 0 || err == nil {
+                       t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
+               }
+               res.Body.Close()
+               n, err = res.Body.Read(buf)
+               if n != 0 || err == nil {
+                       t.Errorf("expected Read error after Close; got %d, %v", n, err)
+               }
        }
 }