return multipart.NewReader(r.Body, boundary), nil
}
+// isH2Upgrade reports whether r represents the http2 "client preface"
+// magic string.
+func (r *Request) isH2Upgrade() bool {
+ return r.Method == "PRI" && len(r.Header) == 0 && r.URL.Path == "*" && r.Proto == "HTTP/2.0"
+}
+
// Return value if nonempty, def otherwise.
func valueOrDefault(value, def string) string {
if value != "" {
return nil, err
}
+ if req.isH2Upgrade() {
+ // Because it's neither chunked, nor declared:
+ req.ContentLength = -1
+
+ // We want to give handlers a chance to hijack the
+ // connection, but we need to prevent the Server from
+ // dealing with the connection further if it's not
+ // hijacked. Set Close to ensure that:
+ req.Close = true
+ }
return req, nil
}
}))
}
+func TestHTTP2UpgradeClosesConnection(t *testing.T) {
+ testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Nothing. (if not hijacked, the server should close the connection
+ // afterwards)
+ }))
+}
+
func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) }
func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) }
{"HTTP/1.0", "", 200},
{"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
{"HTTP/1.0", "Host: \xff\r\n", 400},
+
+ // Make an exception for HTTP upgrade requests:
+ {"PRI * HTTP/2.0", "", 200},
}
for _, tt := range tests {
conn := &testConn{closec: make(chan bool, 1)}
- io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n")
+ methodTarget := "GET / "
+ if !strings.HasPrefix(tt.proto, "HTTP/") {
+ methodTarget = ""
+ }
+ io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
ln := &oneConnListener{conn}
go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
}
}
+func TestServerHandlersCanHandleH2PRI(t *testing.T) {
+ const upgradeResponse = "upgrade here"
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, br, err := w.(Hijacker).Hijack()
+ defer conn.Close()
+ if r.Method != "PRI" || r.RequestURI != "*" {
+ t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
+ return
+ }
+ if !r.Close {
+ t.Errorf("Request.Close = true; want false")
+ }
+ const want = "SM\r\n\r\n"
+ buf := make([]byte, len(want))
+ n, err := io.ReadFull(br, buf)
+ if err != nil || string(buf[:n]) != want {
+ t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
+ return
+ }
+ io.WriteString(conn, upgradeResponse)
+ }))
+ defer ts.Close()
+
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer c.Close()
+ io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
+ slurp, err := ioutil.ReadAll(c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != upgradeResponse {
+ t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
+ }
+}
+
// Test that we validate the valid bytes in HTTP/1 headers.
// Issue 11207.
func TestServerValidatesHeaders(t *testing.T) {
c.r.setInfiniteReadLimit()
hosts, haveHost := req.Header["Host"]
- if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) {
+ isH2Upgrade := req.isH2Upgrade()
+ if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && !isH2Upgrade {
return nil, badRequestError("missing required Host header")
}
if len(hosts) > 1 {
handlerHeader: make(Header),
contentLength: -1,
}
+ if isH2Upgrade {
+ w.closeAfterReply = true
+ }
w.cw.res = w
w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize)
return w, nil