]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Next Protocol Negotation upgrade support to the Server
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 4 Feb 2013 21:55:38 +0000 (13:55 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 4 Feb 2013 21:55:38 +0000 (13:55 -0800)
This provides the mechanism to connect SPDY support to the http
package, without pulling SPDY into the standard library.

R=rsc, agl, mikioh.mikioh
CC=golang-dev
https://golang.org/cl/7287045

src/pkg/net/http/httptest/server.go
src/pkg/net/http/npn_test.go [new file with mode: 0644]
src/pkg/net/http/server.go

index fc52c9a2efce7b0d7f46399969b92ddf603d0a56..c54b76125edeadee1a5ce02c1ff8ffcf06ff3fa1 100644 (file)
@@ -21,7 +21,11 @@ import (
 type Server struct {
        URL      string // base URL of form http://ipaddr:port with no trailing slash
        Listener net.Listener
-       TLS      *tls.Config // nil if not using TLS
+
+       // TLS is the optional TLS configuration, populated with a new config
+       // after TLS is started. If set on an unstarted server before StartTLS
+       // is called, existing fields are copied into the new config.
+       TLS *tls.Config
 
        // Config may be changed after calling NewUnstartedServer and
        // before Start or StartTLS.
@@ -119,9 +123,16 @@ func (s *Server) StartTLS() {
                panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
        }
 
-       s.TLS = &tls.Config{
-               NextProtos:   []string{"http/1.1"},
-               Certificates: []tls.Certificate{cert},
+       existingConfig := s.TLS
+       s.TLS = new(tls.Config)
+       if existingConfig != nil {
+               *s.TLS = *existingConfig
+       }
+       if s.TLS.NextProtos == nil {
+               s.TLS.NextProtos = []string{"http/1.1"}
+       }
+       if len(s.TLS.Certificates) == 0 {
+               s.TLS.Certificates = []tls.Certificate{cert}
        }
        tlsListener := tls.NewListener(s.Listener, s.TLS)
 
diff --git a/src/pkg/net/http/npn_test.go b/src/pkg/net/http/npn_test.go
new file mode 100644 (file)
index 0000000..98b8930
--- /dev/null
@@ -0,0 +1,118 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+       "bufio"
+       "crypto/tls"
+       "fmt"
+       "io"
+       "io/ioutil"
+       . "net/http"
+       "net/http/httptest"
+       "strings"
+       "testing"
+)
+
+func TestNextProtoUpgrade(t *testing.T) {
+       ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
+               if r.TLS != nil {
+                       w.Write([]byte(r.TLS.NegotiatedProtocol))
+               }
+               if r.RemoteAddr == "" {
+                       t.Error("request with no RemoteAddr")
+               }
+               if r.Body == nil {
+                       t.Errorf("request with nil Body")
+               }
+       }))
+       ts.TLS = &tls.Config{
+               NextProtos: []string{"unhandled-proto", "tls-0.9"},
+       }
+       ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
+               "tls-0.9": handleTLSProtocol09,
+       }
+       ts.StartTLS()
+       defer ts.Close()
+
+       tr := newTLSTransport(t, ts)
+       defer tr.CloseIdleConnections()
+       c := &Client{Transport: tr}
+
+       // Normal request, without NPN.
+       {
+               res, err := c.Get(ts.URL)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               body, err := ioutil.ReadAll(res.Body)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if want := "path=/,proto="; string(body) != want {
+                       t.Errorf("plain request = %q; want %q", body, want)
+               }
+       }
+
+       // Request to an advertised but unhandled NPN protocol.
+       // Server will hang up.
+       {
+               tr.CloseIdleConnections()
+               tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"}
+               _, err := c.Get(ts.URL)
+               if err == nil {
+                       t.Errorf("expected error on unhandled-proto request")
+               }
+       }
+
+       // Request using the "tls-0.9" protocol, which we register here.
+       // It is HTTP/0.9 over TLS.
+       {
+               tlsConfig := newTLSTransport(t, ts).TLSClientConfig
+               tlsConfig.NextProtos = []string{"tls-0.9"}
+               conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               conn.Write([]byte("GET /foo\n"))
+               body, err := ioutil.ReadAll(conn)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if want := "path=/foo,proto=tls-0.9"; string(body) != want {
+                       t.Errorf("plain request = %q; want %q", body, want)
+               }
+       }
+}
+
+// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the
+// TestNextProtoUpgrade test.
+func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
+       br := bufio.NewReader(conn)
+       line, err := br.ReadString('\n')
+       if err != nil {
+               return
+       }
+       line = strings.TrimSpace(line)
+       path := strings.TrimPrefix(line, "GET ")
+       if path == line {
+               return
+       }
+       req, _ := NewRequest("GET", path, nil)
+       req.Proto = "HTTP/0.9"
+       req.ProtoMajor = 0
+       req.ProtoMinor = 9
+       rw := &http09Writer{conn, make(Header)}
+       h.ServeHTTP(rw, req)
+}
+
+type http09Writer struct {
+       io.Writer
+       h Header
+}
+
+func (w http09Writer) Header() Header  { return w.h }
+func (w http09Writer) WriteHeader(int) {} // no headers
index e24b0dd9312ac7c09f19d94ae20060cb547a2b6e..e70d129e7ee4ea3a70efdbf7ab3defb7b942ab23 100644 (file)
@@ -774,6 +774,18 @@ func (c *conn) closeWriteAndWait() {
        time.Sleep(rstAvoidanceDelay)
 }
 
+// validNPN returns whether the proto is not a blacklisted Next
+// Protocol Negotiation protocol.  Empty and built-in protocol types
+// are blacklisted and can't be overridden with alternate
+// implementations.
+func validNPN(proto string) bool {
+       switch proto {
+       case "", "http/1.1", "http/1.0":
+               return false
+       }
+       return true
+}
+
 // Serve a new connection.
 func (c *conn) serve() {
        defer func() {
@@ -800,6 +812,13 @@ func (c *conn) serve() {
                }
                c.tlsState = new(tls.ConnectionState)
                *c.tlsState = tlsConn.ConnectionState()
+               if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
+                       if fn := c.server.TLSNextProto[proto]; fn != nil {
+                               h := initNPNRequest{tlsConn, serverHandler{c.server}}
+                               fn(c.server, tlsConn, h)
+                       }
+                       return
+               }
        }
 
        for {
@@ -842,20 +861,12 @@ func (c *conn) serve() {
                        break
                }
 
-               handler := c.server.Handler
-               if handler == nil {
-                       handler = DefaultServeMux
-               }
-               if req.RequestURI == "*" && req.Method == "OPTIONS" {
-                       handler = globalOptionsHandler{}
-               }
-
                // HTTP cannot have multiple simultaneous active requests.[*]
                // Until the server replies to this request, it can't read another,
                // so we might as well run the handler in this goroutine.
                // [*] Not strictly true: HTTP pipelining.  We could let them all process
                // in parallel even if their responses need to be serialized.
-               handler.ServeHTTP(w, w.req)
+               serverHandler{c.server}.ServeHTTP(w, w.req)
                if c.hijacked() {
                        return
                }
@@ -1248,6 +1259,32 @@ type Server struct {
        WriteTimeout   time.Duration // maximum duration before timing out write of the response
        MaxHeaderBytes int           // maximum size of request headers, DefaultMaxHeaderBytes if 0
        TLSConfig      *tls.Config   // optional TLS config, used by ListenAndServeTLS
+
+       // TLSNextProto optionally specifies a function to take over
+       // ownership of the provided TLS connection when an NPN
+       // protocol upgrade has occured.  The map key is the protocol
+       // name negotiated. The Handler argument should be used to
+       // handle HTTP requests and will initialize the Request's TLS
+       // and RemoteAddr if not already set.  The connection is
+       // automatically closed when the function returns.
+       TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+}
+
+// serverHandler delegates to either the server's Handler or
+// DefaultServeMux and also handles "OPTIONS *" requests.
+type serverHandler struct {
+       srv *Server
+}
+
+func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
+       handler := sh.srv.Handler
+       if handler == nil {
+               handler = DefaultServeMux
+       }
+       if req.RequestURI == "*" && req.Method == "OPTIONS" {
+               handler = globalOptionsHandler{}
+       }
+       handler.ServeHTTP(rw, req)
 }
 
 // ListenAndServe listens on the TCP network address srv.Addr and then
@@ -1504,6 +1541,31 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
        }
 }
 
+// eofReader is a non-nil io.ReadCloser that always returns EOF.
+var eofReader = ioutil.NopCloser(strings.NewReader(""))
+
+// initNPNRequest is an HTTP handler that initializes certain
+// uninitialized fields in its *Request. Such partially-initialized
+// Requests come from NPN protocol handlers.
+type initNPNRequest struct {
+       c *tls.Conn
+       h serverHandler
+}
+
+func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
+       if req.TLS == nil {
+               req.TLS = &tls.ConnectionState{}
+               *req.TLS = h.c.ConnectionState()
+       }
+       if req.Body == nil {
+               req.Body = eofReader
+       }
+       if req.RemoteAddr == "" {
+               req.RemoteAddr = h.c.RemoteAddr().String()
+       }
+       h.h.ServeHTTP(rw, req)
+}
+
 // loggingConn is used for debugging.
 type loggingConn struct {
        name string