]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httputil: fix race in DumpRequestOut
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 1 Feb 2012 23:10:14 +0000 (15:10 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 1 Feb 2012 23:10:14 +0000 (15:10 -0800)
Fixes #2715

R=golang-dev, adg
CC=golang-dev
https://golang.org/cl/5614043

src/pkg/net/http/httputil/dump.go

index b8a98ee42929b033497f74acb19719304ddf9298..c853066f1cffa06f17a0f9f76091babe13e474d4 100644 (file)
@@ -5,8 +5,8 @@
 package httputil
 
 import (
+       "bufio"
        "bytes"
-       "errors"
        "fmt"
        "io"
        "io/ioutil"
@@ -47,40 +47,59 @@ func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
 // DumpRequestOut is like DumpRequest but includes
 // headers that the standard http.Transport adds,
 // such as User-Agent.
-func DumpRequestOut(req *http.Request, body bool) (dump []byte, err error) {
+func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
        save := req.Body
        if !body || req.Body == nil {
                req.Body = nil
        } else {
+               var err error
                save, req.Body, err = drainBody(req.Body)
                if err != nil {
-                       return
+                       return nil, err
                }
        }
 
-       var b bytes.Buffer
-       dialed := false
+       // Use the actual Transport code to record what we would send
+       // on the wire, but not using TCP.  Use a Transport with a
+       // customer dialer that returns a fake net.Conn that waits
+       // for the full input (and recording it), and then responds
+       // with a dummy response.
+       var buf bytes.Buffer // records the output
+       pr, pw := io.Pipe()
+       dr := &delegateReader{c: make(chan io.Reader)}
+       // Wait for the request before replying with a dummy response:
+       go func() {
+               http.ReadRequest(bufio.NewReader(pr))
+               dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\n\r\n")
+       }()
+
        t := &http.Transport{
-               Dial: func(net, addr string) (c net.Conn, err error) {
-                       if dialed {
-                               return nil, errors.New("unexpected second dial")
-                       }
-                       c = &dumpConn{
-                               Writer: &b,
-                               Reader: strings.NewReader("HTTP/1.1 500 Fake Error\r\n\r\n"),
-                       }
-                       return
+               Dial: func(net, addr string) (net.Conn, error) {
+                       return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil
                },
        }
 
-       _, err = t.RoundTrip(req)
+       _, err := t.RoundTrip(req)
 
        req.Body = save
        if err != nil {
-               return
+               return nil, err
        }
-       dump = b.Bytes()
-       return
+       return buf.Bytes(), nil
+}
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+       c chan io.Reader
+       r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+       if r.r == nil {
+               r.r = <-r.c
+       }
+       return r.r.Read(p)
 }
 
 // Return value if nonempty, def otherwise.