return buf.String()
}
+func http2transportExpectContinueTimeout(t1 *Transport) time.Duration {
+ return t1.ExpectContinueTimeout
+}
+
type http2contextContext interface {
context.Context
}
ci := httptrace.GotConnInfo{Conn: cc.tconn}
cc.mu.Lock()
ci.Reused = cc.nextStreamID > 1
- ci.WasIdle = len(cc.streams) == 0
- if ci.WasIdle {
+ ci.WasIdle = len(cc.streams) == 0 && ci.Reused
+ if ci.WasIdle && !cc.lastActive.IsZero() {
ci.IdleTime = time.Now().Sub(cc.lastActive)
}
cc.mu.Unlock()
}
}
+func http2traceGot100Continue(trace *http2clientTrace) {
+ if trace != nil && trace.Got100Continue != nil {
+ trace.Got100Continue()
+ }
+}
+
+func http2traceWait100Continue(trace *http2clientTrace) {
+ if trace != nil && trace.Wait100Continue != nil {
+ trace.Wait100Continue()
+ }
+}
+
func http2traceWroteRequest(trace *http2clientTrace, err error) {
if trace != nil && trace.WroteRequest != nil {
trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
resc chan http2resAndError
bufPipe http2pipe // buffered pipe with the flow-controlled response payload
requestedGzip bool
+ on100 func() // optional code to run if get a 100 continue response
flow http2flow // guarded by cc.mu
inflow http2flow // guarded by cc.mu
return t.t1 != nil && t.t1.DisableKeepAlives
}
+func (t *http2Transport) expectContinueTimeout() time.Duration {
+ if t.t1 == nil {
+ return 0
+ }
+ return http2transportExpectContinueTimeout(t.t1)
+}
+
func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
if http2VerboseLogs {
t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
return nil
}
+func http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) {
+ body = req.Body
+ if body == nil {
+ return nil, 0
+ }
+ if req.ContentLength != 0 {
+ return req.Body, req.ContentLength
+ }
+
+ // We have a body but a zero content length. Test to see if
+ // it's actually zero or just unset.
+ var buf [1]byte
+ n, rerr := io.ReadFull(body, buf[:])
+ if rerr != nil && rerr != io.EOF {
+ return http2errorReader{rerr}, -1
+ }
+ if n == 1 {
+
+ return io.MultiReader(bytes.NewReader(buf[:]), body), -1
+ }
+
+ return nil, 0
+}
+
func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
if err := http2checkConnHeaders(req); err != nil {
return nil, err
}
hasTrailers := trailers != ""
- var body io.Reader = req.Body
- contentLen := req.ContentLength
- if req.Body != nil && contentLen == 0 {
- // Test to see if it's actually zero or just unset.
- var buf [1]byte
- n, rerr := io.ReadFull(body, buf[:])
- if rerr != nil && rerr != io.EOF {
- contentLen = -1
- body = http2errorReader{rerr}
- } else if n == 1 {
-
- contentLen = -1
- body = io.MultiReader(bytes.NewReader(buf[:]), body)
- } else {
-
- body = nil
- }
- }
+ body, contentLen := http2bodyAndLength(req)
+ hasBody := body != nil
cc.mu.Lock()
cc.lastActive = time.Now()
cs := cc.newStream()
cs.req = req
cs.trace = http2requestTrace(req)
- hasBody := body != nil
cs.requestedGzip = requestedGzip
+ bodyWriter := cc.t.getBodyWriterState(cs, body)
+ cs.on100 = bodyWriter.on100
cc.wmu.Lock()
endStream := !hasBody && !hasTrailers
if werr != nil {
if hasBody {
req.Body.Close()
+ bodyWriter.cancel()
}
cc.forgetStreamID(cs.ID)
}
var respHeaderTimer <-chan time.Time
- var bodyCopyErrc chan error // result of body copy
if hasBody {
- bodyCopyErrc = make(chan error, 1)
- go func() {
- bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
- }()
+ bodyWriter.scheduleBodyWrite()
} else {
http2traceWroteRequest(cs.trace, nil)
if d := cc.responseHeaderTimeout(); d != 0 {
res := re.res
if re.err != nil || res.StatusCode > 299 {
+ bodyWriter.cancel()
cs.abortRequestBodyWrite(http2errStopReqBodyWrite)
}
if re.err != nil {
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
} else {
+ bodyWriter.cancel()
cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
}
return nil, http2errTimeout
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
} else {
+ bodyWriter.cancel()
cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
}
return nil, ctx.Err()
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
} else {
+ bodyWriter.cancel()
cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
}
return nil, http2errRequestCanceled
case <-cs.peerReset:
return nil, cs.resetErr
- case err := <-bodyCopyErrc:
- http2traceWroteRequest(cs.trace, err)
+ case err := <-bodyWriter.resc:
if err != nil {
return nil, err
}
defer cc.putFrameScratchBuffer(buf)
defer func() {
+ http2traceWroteRequest(cs.trace, err)
cerr := bodyCloser.Close()
if err == nil {
}
if statusCode == 100 {
-
+ http2traceGot100Continue(cs.trace)
+ if cs.on100 != nil {
+ cs.on100()
+ }
cs.pastHeaders = false
return nil, nil
}
func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err }
+// bodyWriterState encapsulates various state around the Transport's writing
+// of the request body, particularly regarding doing delayed writes of the body
+// when the request contains "Expect: 100-continue".
+type http2bodyWriterState struct {
+ cs *http2clientStream
+ timer *time.Timer // if non-nil, we're doing a delayed write
+ fnonce *sync.Once // to call fn with
+ fn func() // the code to run in the goroutine, writing the body
+ resc chan error // result of fn's execution
+ delay time.Duration // how long we should delay a delayed write for
+}
+
+func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reader) (s http2bodyWriterState) {
+ s.cs = cs
+ if body == nil {
+ return
+ }
+ resc := make(chan error, 1)
+ s.resc = resc
+ s.fn = func() {
+ resc <- cs.writeRequestBody(body, cs.req.Body)
+ }
+ s.delay = t.expectContinueTimeout()
+ if s.delay == 0 ||
+ !httplex.HeaderValuesContainsToken(
+ cs.req.Header["Expect"],
+ "100-continue") {
+ return
+ }
+ s.fnonce = new(sync.Once)
+
+ // Arm the timer with a very large duration, which we'll
+ // intentionally lower later. It has to be large now because
+ // we need a handle to it before writing the headers, but the
+ // s.delay value is defined to not start until after the
+ // request headers were written.
+ const hugeDuration = 365 * 24 * time.Hour
+ s.timer = time.AfterFunc(hugeDuration, func() {
+ s.fnonce.Do(s.fn)
+ })
+ return
+}
+
+func (s http2bodyWriterState) cancel() {
+ if s.timer != nil {
+ s.timer.Stop()
+ }
+}
+
+func (s http2bodyWriterState) on100() {
+ if s.timer == nil {
+
+ return
+ }
+ s.timer.Stop()
+ go func() { s.fnonce.Do(s.fn) }()
+}
+
+// scheduleBodyWrite starts writing the body, either immediately (in
+// the common case) or after the delay timeout. It should not be
+// called until after the headers have been written.
+func (s http2bodyWriterState) scheduleBodyWrite() {
+ if s.timer == nil {
+
+ go s.fn()
+ return
+ }
+ http2traceWait100Continue(s.cs.trace)
+ if s.timer.Stop() {
+ s.timer.Reset(s.delay)
+ }
+}
+
// writeFramer is implemented by any type that is used to write frames.
type http2writeFramer interface {
writeFrame(http2writeContext) error
func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
testTransportAutoHTTP(t, &Transport{
ExpectContinueTimeout: 1 * time.Second,
- }, false)
+ }, true)
}
func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
io.WriteString(w, resBody)
}))
defer cst.close()
- if !h2 {
- cst.tr.ExpectContinueTimeout = 1 * time.Second
- }
+
+ cst.tr.ExpectContinueTimeout = 1 * time.Second
var mu sync.Mutex
var buf bytes.Buffer
if err != nil {
t.Fatal(err)
}
+ logf("got roundtrip.response")
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
+ logf("consumed body")
if string(slurp) != resBody || res.StatusCode != 200 {
t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
}
t.Errorf("expected substring %q in output.", sub)
}
}
+ if strings.Count(got, "got conn: {") != 1 {
+ t.Errorf("expected exactly 1 \"got conn\" event.")
+ }
wantSub("Getting conn for dns-is-faked.golang:" + port)
wantSub("DNS start: {Host:dns-is-faked.golang}")
wantSub("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
wantSub("first response byte")
if !h2 {
wantSub("PutIdleConn = <nil>")
- // TODO: implement these next two for Issue 13851
- wantSub("Wait100Continue")
- wantSub("Got100Continue")
}
+ wantSub("Wait100Continue")
+ wantSub("Got100Continue")
wantSub("WroteRequest: {Err:<nil>}")
if strings.Contains(got, " to udp ") {
t.Errorf("should not see UDP (DNS) connections")