]> Cypherpunks repositories - gostls13.git/commitdiff
http: fix chunking bug during content sniffing
authorRuss Cox <rsc@golang.org>
Thu, 21 Jul 2011 18:29:14 +0000 (14:29 -0400)
committerRuss Cox <rsc@golang.org>
Thu, 21 Jul 2011 18:29:14 +0000 (14:29 -0400)
R=golang-dev, bradfitz, gri
CC=golang-dev
https://golang.org/cl/4807044

src/pkg/http/httptest/server.go
src/pkg/http/server.go
src/pkg/http/sniff_test.go

index 879f04f33c27d2330a2015ba27a3409075636b6f..2ec36d04cf67acbeafe493490fad0b107f45c564 100644 (file)
@@ -9,6 +9,7 @@ package httptest
 import (
        "crypto/rand"
        "crypto/tls"
+       "flag"
        "fmt"
        "http"
        "net"
@@ -49,15 +50,34 @@ func newLocalListener() net.Listener {
        return l
 }
 
+// When debugging a particular http server-based test,
+// this flag lets you run
+//     gotest -run=BrokenTest -httptest.serve=127.0.0.1:8000
+// to start the broken server so you can interact with it manually.
+var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
+
 // NewServer starts and returns a new Server.
 // The caller should call Close when finished, to shut it down.
 func NewServer(handler http.Handler) *Server {
        ts := new(Server)
-       l := newLocalListener()
+       var l net.Listener
+       if *serve != "" {
+               var err os.Error
+               l, err = net.Listen("tcp", *serve)
+               if err != nil {
+                       panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
+               }
+       } else {
+               l = newLocalListener()
+       }
        ts.Listener = &historyListener{l, make([]net.Conn, 0)}
        ts.URL = "http://" + l.Addr().String()
        server := &http.Server{Handler: handler}
        go server.Serve(ts.Listener)
+       if *serve != "" {
+               fmt.Println(os.Stderr, "httptest: serving on", ts.URL)
+               select {}
+       }
        return ts
 }
 
index b3fb8e101c3fa61d9632d14d0367a556018a5c29..f14ef8c04b90c0dad7436abeb03af577bb151cf1 100644 (file)
@@ -255,9 +255,7 @@ func (w *response) WriteHeader(code int) {
        } else {
                // If no content type, apply sniffing algorithm to body.
                if w.header.Get("Content-Type") == "" {
-                       // NOTE(dsymonds): the sniffing mechanism in this file is currently broken.
-                       //w.needSniff = true
-                       w.header.Set("Content-Type", "text/html; charset=utf-8")
+                       w.needSniff = true
                }
        }
 
@@ -364,10 +362,16 @@ func (w *response) sniff() {
        fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n", DetectContentType(data))
        io.WriteString(w.conn.buf, "\r\n")
 
-       if w.chunking && len(data) > 0 {
+       if len(data) == 0 {
+               return
+       }
+       if w.chunking {
                fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
        }
-       w.conn.buf.Write(data)
+       _, err := w.conn.buf.Write(data)
+       if w.chunking && err == nil {
+               io.WriteString(w.conn.buf, "\r\n")
+       }
 }
 
 // bodyAllowed returns true if a Write is allowed for this response type.
@@ -401,12 +405,23 @@ func (w *response) Write(data []byte) (n int, err os.Error) {
 
        var m int
        if w.needSniff {
+               // We need to sniff the beginning of the output to
+               // determine the content type.  Accumulate the
+               // initial writes in w.conn.body.
                body := w.conn.body
-               m = copy(body[len(body):], data)
+               m = copy(body[len(body):cap(body)], data)
                w.conn.body = body[:len(body)+m]
                if m == len(data) {
+                       // Copied everything into the buffer.
+                       // Wait for next write.
                        return m, nil
                }
+
+               // Filled the buffer; more data remains.
+               // Sniff the content (flushes the buffer)
+               // and then proceed with the remainder
+               // of the data as a normal Write.
+               // Calling sniff clears needSniff.
                w.sniff()
                data = data[m:]
        }
index 770496f405163b9fce25b1f5f3e8c0a4e9bd5325..2d01807f695aececb7136981df66449c3dd4f79c 100644 (file)
@@ -2,16 +2,22 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package http
+package http_test
 
 import (
+       "bytes"
+       . "http"
+       "http/httptest"
+       "io/ioutil"
+       "log"
+       "strconv"
        "testing"
 )
 
 var sniffTests = []struct {
-       desc string
-       data []byte
-       exp  string
+       desc        string
+       data        []byte
+       contentType string
 }{
        // Some nonsense.
        {"Empty", []byte{}, "text/plain; charset=utf-8"},
@@ -30,11 +36,41 @@ var sniffTests = []struct {
        {"GIF 89a", []byte(`GIF89a...`), "image/gif"},
 }
 
-func TestSniffing(t *testing.T) {
-       for _, st := range sniffTests {
-               got := DetectContentType(st.data)
-               if got != st.exp {
-                       t.Errorf("%v: sniffed as %v, want %v", st.desc, got, st.exp)
+func TestDetectContentType(t *testing.T) {
+       for _, tt := range sniffTests {
+               ct := DetectContentType(tt.data)
+               if ct != tt.contentType {
+                       t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType)
                }
        }
 }
+
+func TestServerContentType(t *testing.T) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               i, _ := strconv.Atoi(r.FormValue("i"))
+               tt := sniffTests[i]
+               n, err := w.Write(tt.data)
+               if n != len(tt.data) || err != nil {
+                       log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data))
+               }
+       }))
+       defer ts.Close()
+
+       for i, tt := range sniffTests {
+               resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i))
+               if err != nil {
+                       t.Errorf("%v: %v", tt.desc, err)
+                       continue
+               }
+               if ct := resp.Header.Get("Content-Type"); ct != tt.contentType {
+                       t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType)
+               }
+               data, err := ioutil.ReadAll(resp.Body)
+               if err != nil {
+                       t.Errorf("%v: reading body: %v", tt.desc, err)
+               } else if !bytes.Equal(data, tt.data) {
+                       t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data)
+               }
+               resp.Body.Close()
+       }
+}