type Server struct {
URL string // base URL of form http://ipaddr:port with no trailing slash
Listener net.Listener
- TLS *tls.Config // nil if not using TLS
+
+ // TLS is the optional TLS configuration, populated with a new config
+ // after TLS is started. If set on an unstarted server before StartTLS
+ // is called, existing fields are copied into the new config.
+ TLS *tls.Config
// Config may be changed after calling NewUnstartedServer and
// before Start or StartTLS.
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
}
- s.TLS = &tls.Config{
- NextProtos: []string{"http/1.1"},
- Certificates: []tls.Certificate{cert},
+ existingConfig := s.TLS
+ s.TLS = new(tls.Config)
+ if existingConfig != nil {
+ *s.TLS = *existingConfig
+ }
+ if s.TLS.NextProtos == nil {
+ s.TLS.NextProtos = []string{"http/1.1"}
+ }
+ if len(s.TLS.Certificates) == 0 {
+ s.TLS.Certificates = []tls.Certificate{cert}
}
tlsListener := tls.NewListener(s.Listener, s.TLS)
--- /dev/null
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bufio"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "io/ioutil"
+ . "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestNextProtoUpgrade(t *testing.T) {
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
+ if r.TLS != nil {
+ w.Write([]byte(r.TLS.NegotiatedProtocol))
+ }
+ if r.RemoteAddr == "" {
+ t.Error("request with no RemoteAddr")
+ }
+ if r.Body == nil {
+ t.Errorf("request with nil Body")
+ }
+ }))
+ ts.TLS = &tls.Config{
+ NextProtos: []string{"unhandled-proto", "tls-0.9"},
+ }
+ ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
+ "tls-0.9": handleTLSProtocol09,
+ }
+ ts.StartTLS()
+ defer ts.Close()
+
+ tr := newTLSTransport(t, ts)
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ // Normal request, without NPN.
+ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/,proto="; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+
+ // Request to an advertised but unhandled NPN protocol.
+ // Server will hang up.
+ {
+ tr.CloseIdleConnections()
+ tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"}
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Errorf("expected error on unhandled-proto request")
+ }
+ }
+
+ // Request using the "tls-0.9" protocol, which we register here.
+ // It is HTTP/0.9 over TLS.
+ {
+ tlsConfig := newTLSTransport(t, ts).TLSClientConfig
+ tlsConfig.NextProtos = []string{"tls-0.9"}
+ conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.Write([]byte("GET /foo\n"))
+ body, err := ioutil.ReadAll(conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/foo,proto=tls-0.9"; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+}
+
+// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the
+// TestNextProtoUpgrade test.
+func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
+ br := bufio.NewReader(conn)
+ line, err := br.ReadString('\n')
+ if err != nil {
+ return
+ }
+ line = strings.TrimSpace(line)
+ path := strings.TrimPrefix(line, "GET ")
+ if path == line {
+ return
+ }
+ req, _ := NewRequest("GET", path, nil)
+ req.Proto = "HTTP/0.9"
+ req.ProtoMajor = 0
+ req.ProtoMinor = 9
+ rw := &http09Writer{conn, make(Header)}
+ h.ServeHTTP(rw, req)
+}
+
+type http09Writer struct {
+ io.Writer
+ h Header
+}
+
+func (w http09Writer) Header() Header { return w.h }
+func (w http09Writer) WriteHeader(int) {} // no headers
time.Sleep(rstAvoidanceDelay)
}
+// validNPN returns whether the proto is not a blacklisted Next
+// Protocol Negotiation protocol. Empty and built-in protocol types
+// are blacklisted and can't be overridden with alternate
+// implementations.
+func validNPN(proto string) bool {
+ switch proto {
+ case "", "http/1.1", "http/1.0":
+ return false
+ }
+ return true
+}
+
// Serve a new connection.
func (c *conn) serve() {
defer func() {
}
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()
+ if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
+ if fn := c.server.TLSNextProto[proto]; fn != nil {
+ h := initNPNRequest{tlsConn, serverHandler{c.server}}
+ fn(c.server, tlsConn, h)
+ }
+ return
+ }
}
for {
break
}
- handler := c.server.Handler
- if handler == nil {
- handler = DefaultServeMux
- }
- if req.RequestURI == "*" && req.Method == "OPTIONS" {
- handler = globalOptionsHandler{}
- }
-
// HTTP cannot have multiple simultaneous active requests.[*]
// Until the server replies to this request, it can't read another,
// so we might as well run the handler in this goroutine.
// [*] Not strictly true: HTTP pipelining. We could let them all process
// in parallel even if their responses need to be serialized.
- handler.ServeHTTP(w, w.req)
+ serverHandler{c.server}.ServeHTTP(w, w.req)
if c.hijacked() {
return
}
WriteTimeout time.Duration // maximum duration before timing out write of the response
MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0
TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS
+
+ // TLSNextProto optionally specifies a function to take over
+ // ownership of the provided TLS connection when an NPN
+ // protocol upgrade has occured. The map key is the protocol
+ // name negotiated. The Handler argument should be used to
+ // handle HTTP requests and will initialize the Request's TLS
+ // and RemoteAddr if not already set. The connection is
+ // automatically closed when the function returns.
+ TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+}
+
+// serverHandler delegates to either the server's Handler or
+// DefaultServeMux and also handles "OPTIONS *" requests.
+type serverHandler struct {
+ srv *Server
+}
+
+func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
+ handler := sh.srv.Handler
+ if handler == nil {
+ handler = DefaultServeMux
+ }
+ if req.RequestURI == "*" && req.Method == "OPTIONS" {
+ handler = globalOptionsHandler{}
+ }
+ handler.ServeHTTP(rw, req)
}
// ListenAndServe listens on the TCP network address srv.Addr and then
}
}
+// eofReader is a non-nil io.ReadCloser that always returns EOF.
+var eofReader = ioutil.NopCloser(strings.NewReader(""))
+
+// initNPNRequest is an HTTP handler that initializes certain
+// uninitialized fields in its *Request. Such partially-initialized
+// Requests come from NPN protocol handlers.
+type initNPNRequest struct {
+ c *tls.Conn
+ h serverHandler
+}
+
+func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
+ if req.TLS == nil {
+ req.TLS = &tls.ConnectionState{}
+ *req.TLS = h.c.ConnectionState()
+ }
+ if req.Body == nil {
+ req.Body = eofReader
+ }
+ if req.RemoteAddr == "" {
+ req.RemoteAddr = h.c.RemoteAddr().String()
+ }
+ h.h.ServeHTTP(rw, req)
+}
+
// loggingConn is used for debugging.
type loggingConn struct {
name string