From: Jonathan Amsterdam Date: Fri, 15 Sep 2023 16:17:15 +0000 (-0400) Subject: net/http: handle MethodNotAllowed X-Git-Tag: go1.22rc1~817 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=6192f4615514a7673fb2318ce3491e162d74d438;p=gostls13.git net/http: handle MethodNotAllowed If no pattern matches a request, but a pattern would have matched if the request had a different method, then serve a 405 (Method Not Allowed), and populate the "Allow" header with the methods that would have succeeded. Updates #61640. Change-Id: I0ae9eb95e62c71ff7766a03043525a97099ac1bb Reviewed-on: https://go-review.googlesource.com/c/go/+/528401 Reviewed-by: Damien Neil LUCI-TryBot-Result: Go LUCI --- diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 1aeb93fe14..18034ce163 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -15,6 +15,7 @@ import ( "io" "math" "mime/multipart" + "net/http" . "net/http" "net/http/httptest" "net/url" @@ -1473,10 +1474,11 @@ func TestPathValue(t *testing.T) { }) server := httptest.NewServer(mux) defer server.Close() - _, err := Get(server.URL + test.url) + res, err := Get(server.URL + test.url) if err != nil { t.Fatal(err) } + res.Body.Close() } } @@ -1499,8 +1501,57 @@ func TestSetPathValue(t *testing.T) { }) server := httptest.NewServer(mux) defer server.Close() - _, err := Get(server.URL + "/a/b/c/d/e") + res, err := Get(server.URL + "/a/b/c/d/e") if err != nil { t.Fatal(err) } + res.Body.Close() +} + +func TestStatus(t *testing.T) { + // The main purpose of this test is to check 405 responses and the Allow header. + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + mux := NewServeMux() + mux.Handle("GET /g", h) + mux.Handle("POST /p", h) + mux.Handle("PATCH /p", h) + mux.Handle("PUT /r", h) + mux.Handle("GET /r/", h) + server := httptest.NewServer(mux) + defer server.Close() + + for _, test := range []struct { + method, path string + wantStatus int + wantAllow string + }{ + {"GET", "/g", 200, ""}, + {"HEAD", "/g", 200, ""}, + {"POST", "/g", 405, "GET, HEAD"}, + {"GET", "/x", 404, ""}, + {"GET", "/p", 405, "PATCH, POST"}, + {"GET", "/./p", 405, "PATCH, POST"}, + {"GET", "/r/", 200, ""}, + {"GET", "/r", 200, ""}, // redirected + {"HEAD", "/r/", 200, ""}, + {"HEAD", "/r", 200, ""}, // redirected + {"PUT", "/r/", 405, "GET, HEAD"}, + {"PUT", "/r", 200, ""}, + } { + req, err := http.NewRequest(test.method, server.URL+test.path, nil) + if err != nil { + t.Fatal(err) + } + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if g, w := res.StatusCode, test.wantStatus; g != w { + t.Errorf("%s %s: got %d, want %d", test.method, test.path, g, w) + } + if g, w := res.Header.Get("Allow"), test.wantAllow; g != w { + t.Errorf("%s %s, Allow: got %q, want %q", test.method, test.path, g, w) + } + } } diff --git a/src/net/http/routing_tree.go b/src/net/http/routing_tree.go index e225b5fd3f..46287174a5 100644 --- a/src/net/http/routing_tree.go +++ b/src/net/http/routing_tree.go @@ -220,3 +220,30 @@ func firstSegment(path string) (seg, rest string) { } return path[:i], path[i:] } + +// matchingMethods adds to methodSet all the methods that would result in a +// match if passed to routingNode.match with the given host and path. +func (root *routingNode) matchingMethods(host, path string, methodSet map[string]bool) { + if host != "" { + root.findChild(host).matchingMethodsPath(path, methodSet) + } + root.emptyChild.matchingMethodsPath(path, methodSet) + if methodSet["GET"] { + methodSet["HEAD"] = true + } +} + +func (n *routingNode) matchingMethodsPath(path string, set map[string]bool) { + if n == nil { + return + } + n.children.eachPair(func(method string, c *routingNode) bool { + if p, _ := c.matchPath(path, nil); p != nil { + set[method] = true + } + return true + }) + // Don't look at the empty child. If there were an empty + // child, it would match on any method, but we only + // call this when we fail to match on a method. +} diff --git a/src/net/http/routing_tree_test.go b/src/net/http/routing_tree_test.go index 42d7b99542..149349f397 100644 --- a/src/net/http/routing_tree_test.go +++ b/src/net/http/routing_tree_test.go @@ -209,6 +209,65 @@ func TestRoutingNodeMatch(t *testing.T) { }) } +func TestMatchingMethods(t *testing.T) { + hostTree := buildTree("GET a.com/", "PUT b.com/", "POST /foo/{x}") + for _, test := range []struct { + name string + tree *routingNode + host, path string + want string + }{ + { + "post", + buildTree("POST /"), "", "/foo", + "POST", + }, + { + "get", + buildTree("GET /"), "", "/foo", + "GET,HEAD", + }, + { + "host", + hostTree, "", "/foo", + "", + }, + { + "host", + hostTree, "", "/foo/bar", + "POST", + }, + { + "host2", + hostTree, "a.com", "/foo/bar", + "GET,HEAD,POST", + }, + { + "host3", + hostTree, "b.com", "/bar", + "PUT", + }, + { + // This case shouldn't come up because we only call matchingMethods + // when there was no match, but we include it for completeness. + "empty", + buildTree("/"), "", "/", + "", + }, + } { + t.Run(test.name, func(t *testing.T) { + ms := map[string]bool{} + test.tree.matchingMethods(test.host, test.path, ms) + keys := mapKeys(ms) + sort.Strings(keys) + got := strings.Join(keys, ",") + if got != test.want { + t.Errorf("got %s, want %s", got, test.want) + } + }) + } +} + func (n *routingNode) print(w io.Writer, level int) { indent := strings.Repeat(" ", level) if n.pattern != nil { diff --git a/src/net/http/server.go b/src/net/http/server.go index a229169197..bc5bcb9a71 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -23,6 +23,7 @@ import ( urlpkg "net/url" "path" "runtime" + "sort" "strconv" "strings" "sync" @@ -2423,13 +2424,13 @@ func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string, _ *patte // TODO(jba): use escaped path. This is an independent change that is also part // of proposal https://go.dev/issue/61410. path := r.URL.Path - + host := r.URL.Host // CONNECT requests are not canonicalized. if r.Method == "CONNECT" { // If r.URL.Path is /tree and its handler is not registered, // the /tree -> /tree/ redirect applies to CONNECT requests // but the path canonicalization does not. - _, _, u := mux.matchOrRedirect(r.URL.Host, r.Method, path, r.URL) + _, _, u := mux.matchOrRedirect(host, r.Method, path, r.URL) if u != nil { return RedirectHandler(u.String(), StatusMovedPermanently), u.Path, nil, nil } @@ -2439,7 +2440,7 @@ func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string, _ *patte } else { // All other requests have any port stripped and path cleaned // before passing to mux.handler. - host := stripHostPort(r.Host) + host = stripHostPort(r.Host) path = cleanPath(path) // If the given path is /tree and its handler is not registered, @@ -2460,7 +2461,16 @@ func (mux *ServeMux) findHandler(r *Request) (h Handler, patStr string, _ *patte } } if n == nil { - // TODO(jba): support 405 (MethodNotAllowed) by checking for patterns with different methods. + // We didn't find a match with the request method. To distinguish between + // Not Found and Method Not Allowed, see if there is another pattern that + // matches except for the method. + allowedMethods := mux.matchingMethods(host, path) + if len(allowedMethods) > 0 { + return HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + Error(w, StatusText(StatusMethodNotAllowed), StatusMethodNotAllowed) + }), "", nil, nil + } return NotFoundHandler(), "", nil, nil } return n.handler, n.pattern.String(), n.pattern, matches @@ -2542,6 +2552,30 @@ func exactMatch(n *routingNode, path string) bool { return len(n.pattern.segments) == strings.Count(path, "/") } +// matchingMethods return a sorted list of all methods that would match with the given host and path. +func (mux *ServeMux) matchingMethods(host, path string) []string { + // Hold the read lock for the entire method so that the two matches are done + // on the same set of registered patterns. + mux.mu.RLock() + defer mux.mu.RUnlock() + ms := map[string]bool{} + mux.tree.matchingMethods(host, path, ms) + // matchOrRedirect will try appending a trailing slash if there is no match. + mux.tree.matchingMethods(host, path+"/", ms) + methods := mapKeys(ms) + sort.Strings(methods) + return methods +} + +// TODO: replace with maps.Keys when it is defined. +func mapKeys[K comparable, V any](m map[K]V) []K { + var ks []K + for k := range m { + ks = append(ks, k) + } + return ks +} + // ServeHTTP dispatches the request to the handler whose // pattern most closely matches the request URL. func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {