]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: handle MethodNotAllowed
authorJonathan Amsterdam <jba@google.com>
Fri, 15 Sep 2023 16:17:15 +0000 (12:17 -0400)
committerJonathan Amsterdam <jba@google.com>
Fri, 15 Sep 2023 17:51:17 +0000 (17:51 +0000)
If no pattern matches a request, but a pattern would have
matched if the request had a different method, then
serve a 405 (Method Not Allowed), and populate the
"Allow" header with the methods that would have succeeded.

Updates #61640.

Change-Id: I0ae9eb95e62c71ff7766a03043525a97099ac1bb
Reviewed-on: https://go-review.googlesource.com/c/go/+/528401
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/net/http/request_test.go
src/net/http/routing_tree.go
src/net/http/routing_tree_test.go
src/net/http/server.go

index 1aeb93fe14b4f68c006edd53c07eb4b8a5edb766..18034ce163f6f6bd037424a34cde6a62f345bb2b 100644 (file)
@@ -15,6 +15,7 @@ import (
        "io"
        "math"
        "mime/multipart"
+       "net/http"
        . "net/http"
        "net/http/httptest"
        "net/url"
@@ -1473,10 +1474,11 @@ func TestPathValue(t *testing.T) {
                })
                server := httptest.NewServer(mux)
                defer server.Close()
-               _, err := Get(server.URL + test.url)
+               res, err := Get(server.URL + test.url)
                if err != nil {
                        t.Fatal(err)
                }
+               res.Body.Close()
        }
 }
 
@@ -1499,8 +1501,57 @@ func TestSetPathValue(t *testing.T) {
        })
        server := httptest.NewServer(mux)
        defer server.Close()
-       _, err := Get(server.URL + "/a/b/c/d/e")
+       res, err := Get(server.URL + "/a/b/c/d/e")
        if err != nil {
                t.Fatal(err)
        }
+       res.Body.Close()
+}
+
+func TestStatus(t *testing.T) {
+       // The main purpose of this test is to check 405 responses and the Allow header.
+       h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
+       mux := NewServeMux()
+       mux.Handle("GET /g", h)
+       mux.Handle("POST /p", h)
+       mux.Handle("PATCH /p", h)
+       mux.Handle("PUT /r", h)
+       mux.Handle("GET /r/", h)
+       server := httptest.NewServer(mux)
+       defer server.Close()
+
+       for _, test := range []struct {
+               method, path string
+               wantStatus   int
+               wantAllow    string
+       }{
+               {"GET", "/g", 200, ""},
+               {"HEAD", "/g", 200, ""},
+               {"POST", "/g", 405, "GET, HEAD"},
+               {"GET", "/x", 404, ""},
+               {"GET", "/p", 405, "PATCH, POST"},
+               {"GET", "/./p", 405, "PATCH, POST"},
+               {"GET", "/r/", 200, ""},
+               {"GET", "/r", 200, ""}, // redirected
+               {"HEAD", "/r/", 200, ""},
+               {"HEAD", "/r", 200, ""}, // redirected
+               {"PUT", "/r/", 405, "GET, HEAD"},
+               {"PUT", "/r", 200, ""},
+       } {
+               req, err := http.NewRequest(test.method, server.URL+test.path, nil)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               res, err := http.DefaultClient.Do(req)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               res.Body.Close()
+               if g, w := res.StatusCode, test.wantStatus; g != w {
+                       t.Errorf("%s %s: got %d, want %d", test.method, test.path, g, w)
+               }
+               if g, w := res.Header.Get("Allow"), test.wantAllow; g != w {
+                       t.Errorf("%s %s, Allow: got %q, want %q", test.method, test.path, g, w)
+               }
+       }
 }
index e225b5fd3ff240d965bae90c079a40ea38aeaf60..46287174a546aaf7ee8b2615936b4015b9f9d5a6 100644 (file)
@@ -220,3 +220,30 @@ func firstSegment(path string) (seg, rest string) {
        }
        return path[:i], path[i:]
 }
+
+// matchingMethods adds to methodSet all the methods that would result in a
+// match if passed to routingNode.match with the given host and path.
+func (root *routingNode) matchingMethods(host, path string, methodSet map[string]bool) {
+       if host != "" {
+               root.findChild(host).matchingMethodsPath(path, methodSet)
+       }
+       root.emptyChild.matchingMethodsPath(path, methodSet)
+       if methodSet["GET"] {
+               methodSet["HEAD"] = true
+       }
+}
+
+func (n *routingNode) matchingMethodsPath(path string, set map[string]bool) {
+       if n == nil {
+               return
+       }
+       n.children.eachPair(func(method string, c *routingNode) bool {
+               if p, _ := c.matchPath(path, nil); p != nil {
+                       set[method] = true
+               }
+               return true
+       })
+       // Don't look at the empty child. If there were an empty
+       // child, it would match on any method, but we only
+       // call this when we fail to match on a method.
+}
index 42d7b995427dc2f404132dfab518203d1d4226b2..149349f39703e09749f0bfab9d38d1a120e65904 100644 (file)
@@ -209,6 +209,65 @@ func TestRoutingNodeMatch(t *testing.T) {
        })
 }
 
+func TestMatchingMethods(t *testing.T) {
+       hostTree := buildTree("GET a.com/", "PUT b.com/", "POST /foo/{x}")
+       for _, test := range []struct {
+               name       string
+               tree       *routingNode
+               host, path string
+               want       string
+       }{
+               {
+                       "post",
+                       buildTree("POST /"), "", "/foo",
+                       "POST",
+               },
+               {
+                       "get",
+                       buildTree("GET /"), "", "/foo",
+                       "GET,HEAD",
+               },
+               {
+                       "host",
+                       hostTree, "", "/foo",
+                       "",
+               },
+               {
+                       "host",
+                       hostTree, "", "/foo/bar",
+                       "POST",
+               },
+               {
+                       "host2",
+                       hostTree, "a.com", "/foo/bar",
+                       "GET,HEAD,POST",
+               },
+               {
+                       "host3",
+                       hostTree, "b.com", "/bar",
+                       "PUT",
+               },
+               {
+                       // This case shouldn't come up because we only call matchingMethods
+                       // when there was no match, but we include it for completeness.
+                       "empty",
+                       buildTree("/"), "", "/",
+                       "",
+               },
+       } {
+               t.Run(test.name, func(t *testing.T) {
+                       ms := map[string]bool{}
+                       test.tree.matchingMethods(test.host, test.path, ms)
+                       keys := mapKeys(ms)
+                       sort.Strings(keys)
+                       got := strings.Join(keys, ",")
+                       if got != test.want {
+                               t.Errorf("got %s, want %s", got, test.want)
+                       }
+               })
+       }
+}
+
 func (n *routingNode) print(w io.Writer, level int) {
        indent := strings.Repeat("    ", level)
        if n.pattern != nil {
index a22916919705dfd6c23272b6c0756bc6ada4b04b..bc5bcb9a7171859cbf3f6e395ae43f8048998e5f 100644 (file)
@@ -23,6 +23,7 @@ import (
        urlpkg "net/url"
        "path"
        "runtime"
+       "sort"
        "strconv"
        "strings"
        "sync"
@@ -2423,13 +2424,13 @@ func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string, _ *patte
        // TODO(jba): use escaped path. This is an independent change that is also part
        // of proposal https://go.dev/issue/61410.
        path := r.URL.Path
-
+       host := r.URL.Host
        // CONNECT requests are not canonicalized.
        if r.Method == "CONNECT" {
                // If r.URL.Path is /tree and its handler is not registered,
                // the /tree -> /tree/ redirect applies to CONNECT requests
                // but the path canonicalization does not.
-               _, _, u := mux.matchOrRedirect(r.URL.Host, r.Method, path, r.URL)
+               _, _, u := mux.matchOrRedirect(host, r.Method, path, r.URL)
                if u != nil {
                        return RedirectHandler(u.String(), StatusMovedPermanently), u.Path, nil, nil
                }
@@ -2439,7 +2440,7 @@ func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string, _ *patte
        } else {
                // All other requests have any port stripped and path cleaned
                // before passing to mux.handler.
-               host := stripHostPort(r.Host)
+               host = stripHostPort(r.Host)
                path = cleanPath(path)
 
                // If the given path is /tree and its handler is not registered,
@@ -2460,7 +2461,16 @@ func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string, _ *patte
                }
        }
        if n == nil {
-               // TODO(jba): support 405 (MethodNotAllowed) by checking for patterns with different methods.
+               // We didn't find a match with the request method. To distinguish between
+               // Not Found and Method Not Allowed, see if there is another pattern that
+               // matches except for the method.
+               allowedMethods := mux.matchingMethods(host, path)
+               if len(allowedMethods) > 0 {
+                       return HandlerFunc(func(w ResponseWriter, r *Request) {
+                               w.Header().Set("Allow", strings.Join(allowedMethods, ", "))
+                               Error(w, StatusText(StatusMethodNotAllowed), StatusMethodNotAllowed)
+                       }), "", nil, nil
+               }
                return NotFoundHandler(), "", nil, nil
        }
        return n.handler, n.pattern.String(), n.pattern, matches
@@ -2542,6 +2552,30 @@ func exactMatch(n *routingNode, path string) bool {
        return len(n.pattern.segments) == strings.Count(path, "/")
 }
 
+// matchingMethods return a sorted list of all methods that would match with the given host and path.
+func (mux *ServeMux) matchingMethods(host, path string) []string {
+       // Hold the read lock for the entire method so that the two matches are done
+       // on the same set of registered patterns.
+       mux.mu.RLock()
+       defer mux.mu.RUnlock()
+       ms := map[string]bool{}
+       mux.tree.matchingMethods(host, path, ms)
+       // matchOrRedirect will try appending a trailing slash if there is no match.
+       mux.tree.matchingMethods(host, path+"/", ms)
+       methods := mapKeys(ms)
+       sort.Strings(methods)
+       return methods
+}
+
+// TODO: replace with maps.Keys when it is defined.
+func mapKeys[K comparable, V any](m map[K]V) []K {
+       var ks []K
+       for k := range m {
+               ks = append(ks, k)
+       }
+       return ks
+}
+
 // ServeHTTP dispatches the request to the handler whose
 // pattern most closely matches the request URL.
 func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {